From fb875e9fb8933310829dc2107863882a3f8f4408 Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 28 Oct 2025 12:12:54 +0100 Subject: [PATCH 1/2] Add SimpleProxy to setitem --- src/blosc2/ndarray.py | 26 +++++++++++++------------- src/blosc2/proxy.py | 2 +- tests/ndarray/test_setitem.py | 9 +++++++++ 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/blosc2/ndarray.py b/src/blosc2/ndarray.py index 2cb39c1c..6b3ff245 100644 --- a/src/blosc2/ndarray.py +++ b/src/blosc2/ndarray.py @@ -4261,6 +4261,9 @@ def __setitem__( key_, mask = process_key(key, self.shape) # internally handles key an integer if hasattr(value, "shape") and value.shape == (): value = value.item() + value = ( + value if np.isscalar(value) else blosc2.as_simpleproxy(value) + ) # convert to SimpleProxy for e.g. JAX, Tensorflow, PyTorch if builtins.any(isinstance(k, (list, np.ndarray)) for k in key_): # fancy indexing _slice = ndindex.ndindex(key_).expand( @@ -4284,20 +4287,17 @@ def __setitem__( return self._get_set_nonunit_steps((start, stop, step, mask), value=value) shape = [sp - st for sp, st in zip(stop, start, strict=False)] - if isinstance(value, NDArray): - value = value[...] # convert to numpy - if np.isscalar(value): + if isinstance(value, blosc2.Operand): # handles SimpleProxy, NDArray, LazyExpr etc. + value = value[()] # convert to numpy + if np.isscalar(value) or value.shape == (): value = np.full(shape, value, dtype=self.dtype) - elif isinstance(value, np.ndarray): # handles decompressed NDArray too - if value.dtype != self.dtype: - try: - value = value.astype(self.dtype) - except ComplexWarning: - # numexpr type inference can lead to unnecessary type promotions - # when using complex functions (e.g. conj) with real arrays - value = value.real.astype(self.dtype) - if value.shape == (): - value = np.full(shape, value, dtype=self.dtype) + if isinstance(value, np.ndarray) and value.dtype != self.dtype: # handles decompressed NDArray too + try: + value = value.astype(self.dtype) + except ComplexWarning: + # numexpr type inference can lead to unnecessary type promotions + # when using complex functions (e.g. conj) with real arrays + value = value.real.astype(self.dtype) return super().set_slice((start, stop), value) diff --git a/src/blosc2/proxy.py b/src/blosc2/proxy.py index f829e503..15e01d5d 100644 --- a/src/blosc2/proxy.py +++ b/src/blosc2/proxy.py @@ -612,7 +612,7 @@ class SimpleProxy(blosc2.Operand): def __init__(self, src, chunks: tuple | None = None, blocks: tuple | None = None): if not hasattr(src, "shape") or not hasattr(src, "dtype"): - # If the source is not a NumPy array, convert it to one + # If the source is not an array, convert it to NumPy src = np.asarray(src) if not hasattr(src, "__getitem__"): raise TypeError("The source must have a __getitem__ method") diff --git a/tests/ndarray/test_setitem.py b/tests/ndarray/test_setitem.py index d54a017d..a2145b90 100644 --- a/tests/ndarray/test_setitem.py +++ b/tests/ndarray/test_setitem.py @@ -8,6 +8,7 @@ import numpy as np import pytest +import torch import blosc2 @@ -44,6 +45,14 @@ def test_setitem(shape, chunks, blocks, slices, dtype): nparray[slices] = val np.testing.assert_almost_equal(a[...], nparray) + # Object called via SimpleProxy + slice_shape = a[slices].shape + dtype_ = {np.float32: torch.float32, np.int32: torch.int32, np.float64: torch.float64}[dtype] + val = torch.ones(slice_shape, dtype=dtype_) + a[slices] = val + nparray[slices] = val + np.testing.assert_almost_equal(a[...], nparray) + # blosc2.NDArray if np.prod(slice_shape) == 1 or len(slice_shape) != len(blocks): chunks = None From 3ddd9a330b2a3bf22bd4ccd512e43e98891907bd Mon Sep 17 00:00:00 2001 From: lshaw8317 Date: Tue, 28 Oct 2025 12:58:51 +0100 Subject: [PATCH 2/2] Minor simplification --- src/blosc2/ndarray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/blosc2/ndarray.py b/src/blosc2/ndarray.py index 6b3ff245..c3142b44 100644 --- a/src/blosc2/ndarray.py +++ b/src/blosc2/ndarray.py @@ -4291,7 +4291,7 @@ def __setitem__( value = value[()] # convert to numpy if np.isscalar(value) or value.shape == (): value = np.full(shape, value, dtype=self.dtype) - if isinstance(value, np.ndarray) and value.dtype != self.dtype: # handles decompressed NDArray too + if value.dtype != self.dtype: # handles decompressed NDArray too try: value = value.astype(self.dtype) except ComplexWarning: