From 08ce7d97af32eb4bfd5a48e260f5c8c9e4b9b784 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 23 May 2026 07:53:50 +0200 Subject: [PATCH 1/3] Fix JAX assignment index normalization --- src/pyrecest/_backend/jax/__init__.py | 125 ++++++++++++++++++++------ 1 file changed, 96 insertions(+), 29 deletions(-) diff --git a/src/pyrecest/_backend/jax/__init__.py b/src/pyrecest/_backend/jax/__init__.py index 834e36756..3c662f532 100644 --- a/src/pyrecest/_backend/jax/__init__.py +++ b/src/pyrecest/_backend/jax/__init__.py @@ -4,6 +4,7 @@ """ import builtins as _builtins +import numbers as _numbers import jax.numpy as _jnp from jax import vmap @@ -242,6 +243,77 @@ def take( ) +def _is_boolean_index(indices): + if isinstance(indices, (bool, _jnp.bool_)): + return True + if isinstance(indices, (list, tuple)): + return bool(indices) and _is_boolean_index(indices[0]) + if isinstance(indices, _jnp.ndarray): + return indices.dtype in (_jnp.bool_, _jnp.uint8) + return False + + +def _is_iterable_index(indices): + if isinstance(indices, (list, tuple)): + return True + if isinstance(indices, _jnp.ndarray): + return indices.ndim > 0 + return False + + +def _is_scalar_index(index): + return isinstance(index, _numbers.Integral) or ( + isinstance(index, _jnp.ndarray) and index.ndim == 0 + ) + + +def _assignment_value_length(values): + if isinstance(values, (list, tuple)): + return len(values) + if isinstance(values, _jnp.ndarray) and values.ndim > 0: + return values.shape[0] + return 1 + + +def _normalize_assignment_index(indices, ndim_x, axis=0): + if _is_boolean_index(indices): + return _jnp.asarray(indices), False, None + + use_vectorization = _is_iterable_index(indices) and len(indices) < ndim_x + zip_indices = ( + _is_iterable_index(indices) + and len(indices) > 0 + and _is_iterable_index(indices[0]) + ) + + if use_vectorization: + normalized = tuple(list(indices[:axis]) + [slice(None)] + list(indices[axis:])) + return normalized, True, None + + if zip_indices: + normalized = tuple(_jnp.asarray(index_axis) for index_axis in zip(*indices)) + return normalized, False, len(indices) + + if isinstance(indices, list): + return _jnp.asarray(indices), False, len(indices) + if isinstance(indices, _jnp.ndarray) and indices.ndim > 0: + return indices, False, indices.shape[0] + if isinstance(indices, tuple): + if all(_is_scalar_index(index) for index in indices): + return indices, False, 1 + if indices and _is_iterable_index(indices[0]): + return indices, False, len(indices[0]) + return indices, False, 1 + + +def _validate_assignment_value_count(values, *, use_vectorization, len_indices): + if use_vectorization or len_indices is None: + return + len_values = _assignment_value_length(values) + if len_values > 1 and len_values != len_indices: + raise ValueError("Either one value or as many values as indices required") + + def assignment(x, values, indices, axis=0): """ Assign values at given indices of an array using JAX. @@ -265,20 +337,18 @@ def assignment(x, values, indices, axis=0): x_new : JAX array, shape=[dim] Copy of x with the values assigned at the given indices. """ - # Ensure indices and values are iterable - if isinstance(indices, (int, tuple)): - indices = [indices] - if not isinstance(values, list): - values = [values] * len(indices) - - # Check if we need to raise errors for mismatch in values and indices lengths - if len(values) != 1 and len(values) != len(indices): - raise ValueError("Either one value or as many values as indices required") - - # Handling assignment with index update - x_new = x.at[indices].set(values) - - return x_new + x = _jnp.asarray(x) + normalized_indices, use_vectorization, len_indices = _normalize_assignment_index( + indices, + x.ndim, + axis=axis, + ) + _validate_assignment_value_count( + values, + use_vectorization=use_vectorization, + len_indices=len_indices, + ) + return x.at[normalized_indices].set(values) def assignment_by_sum(x, values, indices, axis=0): @@ -308,21 +378,18 @@ def assignment_by_sum(x, values, indices, axis=0): If a single value is provided, it is added at all the indices. If a list is given, it must have the same length as indices. """ - # Ensure indices and values are iterable - if isinstance(indices, (int, tuple)): - indices = [indices] - if not isinstance(values, list): - values = [values] * len(indices) - - # Check if the number of values matches the number of indices, or there's exactly one value - if len(values) != 1 and len(values) != len(indices): - raise ValueError("Either one value or as many values as indices required") - - # Handling addition with index update - for idx, val in zip(indices, values): - x = x.at[idx].add(val) - - return x + x = _jnp.asarray(x) + normalized_indices, use_vectorization, len_indices = _normalize_assignment_index( + indices, + x.ndim, + axis=axis, + ) + _validate_assignment_value_count( + values, + use_vectorization=use_vectorization, + len_indices=len_indices, + ) + return x.at[normalized_indices].add(values) def array_from_sparse(indices, data, target_shape): From bdb9a5d71e066f03990f2819a2d15330d7a8a820 Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 23 May 2026 07:59:06 +0200 Subject: [PATCH 2/3] Add JAX assignment indexing regression tests --- .../test_jax_assignment_contract.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 tests/backend_support/test_jax_assignment_contract.py diff --git a/tests/backend_support/test_jax_assignment_contract.py b/tests/backend_support/test_jax_assignment_contract.py new file mode 100644 index 000000000..23bff82d6 --- /dev/null +++ b/tests/backend_support/test_jax_assignment_contract.py @@ -0,0 +1,59 @@ +"""Regression tests for JAX backend assignment helper indexing.""" + +from __future__ import annotations + +import importlib.util + +import pytest + +from tests.support.backend_runner import run_backend_code + + +@pytest.mark.backend_portable +def test_jax_assignment_accepts_numpy_style_advanced_indices(): + if importlib.util.find_spec("jax") is None: + pytest.skip("JAX is not installed") + + result = run_backend_code( + "jax", + """ +import pyrecest.backend as backend + +x = backend.zeros((3, 3)) +assigned = backend.assignment(x, [5.0, 7.0], [(0, 1), (1, 2)]) +summed = backend.assignment_by_sum(backend.ones((3, 3)), [5.0, 7.0], [(0, 1), (1, 2)]) + +assert backend.to_numpy(assigned).tolist() == [[0.0, 5.0, 0.0], [0.0, 0.0, 7.0], [0.0, 0.0, 0.0]] +assert backend.to_numpy(summed).tolist() == [[1.0, 6.0, 1.0], [1.0, 1.0, 8.0], [1.0, 1.0, 1.0]] +""", + ) + + assert result.returncode == 0, result.stderr + + +@pytest.mark.backend_portable +def test_jax_assignment_accepts_list_and_boolean_indices(): + if importlib.util.find_spec("jax") is None: + pytest.skip("JAX is not installed") + + result = run_backend_code( + "jax", + """ +import pyrecest.backend as backend + +vector = backend.zeros(3) +by_list = backend.assignment(vector, [4.0, 5.0], [0, 2]) + +matrix = backend.zeros((3, 3)) +by_mask = backend.assignment( + matrix, + [1.0, 2.0], + [[True, False, False], [False, True, False], [False, False, False]], +) + +assert backend.to_numpy(by_list).tolist() == [4.0, 0.0, 5.0] +assert backend.to_numpy(by_mask).tolist() == [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]] +""", + ) + + assert result.returncode == 0, result.stderr From c99832627f9572e6f9e181a999064fc962e96d0e Mon Sep 17 00:00:00 2001 From: Florian Pfaff <6773539+FlorianPfaff@users.noreply.github.com> Date: Sat, 23 May 2026 09:05:02 +0200 Subject: [PATCH 3/3] Use built-in all for JAX assignment index checks --- src/pyrecest/_backend/jax/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyrecest/_backend/jax/__init__.py b/src/pyrecest/_backend/jax/__init__.py index 3c662f532..4a7093a9e 100644 --- a/src/pyrecest/_backend/jax/__init__.py +++ b/src/pyrecest/_backend/jax/__init__.py @@ -299,7 +299,7 @@ def _normalize_assignment_index(indices, ndim_x, axis=0): if isinstance(indices, _jnp.ndarray) and indices.ndim > 0: return indices, False, indices.shape[0] if isinstance(indices, tuple): - if all(_is_scalar_index(index) for index in indices): + if _builtins.all(_is_scalar_index(index) for index in indices): return indices, False, 1 if indices and _is_iterable_index(indices[0]): return indices, False, len(indices[0]) @@ -588,4 +588,4 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low): matrix = matrix.at[..., i, i].set(diag) matrix = matrix.at[..., j, k].set(tri_upp) matrix = matrix.at[..., k, j].set(tri_low) - return matrix + return matrix \ No newline at end of file