Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 97 additions & 30 deletions src/pyrecest/_backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import builtins as _builtins
import numbers as _numbers

import jax.numpy as _jnp
from jax import vmap
Expand Down Expand Up @@ -242,6 +243,77 @@ def take(
)


def _is_boolean_index(indices):
if isinstance(indices, (bool, _jnp.bool_)):
return True
if isinstance(indices, (list, tuple)):
return bool(indices) and _is_boolean_index(indices[0])
if isinstance(indices, _jnp.ndarray):
return indices.dtype in (_jnp.bool_, _jnp.uint8)
return False


def _is_iterable_index(indices):
if isinstance(indices, (list, tuple)):
return True
if isinstance(indices, _jnp.ndarray):
return indices.ndim > 0
return False


def _is_scalar_index(index):
return isinstance(index, _numbers.Integral) or (
isinstance(index, _jnp.ndarray) and index.ndim == 0
)


def _assignment_value_length(values):
if isinstance(values, (list, tuple)):
return len(values)
if isinstance(values, _jnp.ndarray) and values.ndim > 0:
return values.shape[0]
return 1


def _normalize_assignment_index(indices, ndim_x, axis=0):
if _is_boolean_index(indices):
return _jnp.asarray(indices), False, None

use_vectorization = _is_iterable_index(indices) and len(indices) < ndim_x
zip_indices = (
_is_iterable_index(indices)
and len(indices) > 0
and _is_iterable_index(indices[0])
)

if use_vectorization:
normalized = tuple(list(indices[:axis]) + [slice(None)] + list(indices[axis:]))
return normalized, True, None

if zip_indices:
normalized = tuple(_jnp.asarray(index_axis) for index_axis in zip(*indices))
return normalized, False, len(indices)

if isinstance(indices, list):
return _jnp.asarray(indices), False, len(indices)
if isinstance(indices, _jnp.ndarray) and indices.ndim > 0:
return indices, False, indices.shape[0]
if isinstance(indices, tuple):
if _builtins.all(_is_scalar_index(index) for index in indices):
return indices, False, 1
if indices and _is_iterable_index(indices[0]):
return indices, False, len(indices[0])
return indices, False, 1


def _validate_assignment_value_count(values, *, use_vectorization, len_indices):
if use_vectorization or len_indices is None:
return
len_values = _assignment_value_length(values)
if len_values > 1 and len_values != len_indices:
raise ValueError("Either one value or as many values as indices required")


def assignment(x, values, indices, axis=0):
"""
Assign values at given indices of an array using JAX.
Expand All @@ -265,20 +337,18 @@ def assignment(x, values, indices, axis=0):
x_new : JAX array, shape=[dim]
Copy of x with the values assigned at the given indices.
"""
# Ensure indices and values are iterable
if isinstance(indices, (int, tuple)):
indices = [indices]
if not isinstance(values, list):
values = [values] * len(indices)

# Check if we need to raise errors for mismatch in values and indices lengths
if len(values) != 1 and len(values) != len(indices):
raise ValueError("Either one value or as many values as indices required")

# Handling assignment with index update
x_new = x.at[indices].set(values)

return x_new
x = _jnp.asarray(x)
normalized_indices, use_vectorization, len_indices = _normalize_assignment_index(
indices,
x.ndim,
axis=axis,
)
_validate_assignment_value_count(
values,
use_vectorization=use_vectorization,
len_indices=len_indices,
)
return x.at[normalized_indices].set(values)


def assignment_by_sum(x, values, indices, axis=0):
Expand Down Expand Up @@ -308,21 +378,18 @@ def assignment_by_sum(x, values, indices, axis=0):
If a single value is provided, it is added at all the indices.
If a list is given, it must have the same length as indices.
"""
# Ensure indices and values are iterable
if isinstance(indices, (int, tuple)):
indices = [indices]
if not isinstance(values, list):
values = [values] * len(indices)

# Check if the number of values matches the number of indices, or there's exactly one value
if len(values) != 1 and len(values) != len(indices):
raise ValueError("Either one value or as many values as indices required")

# Handling addition with index update
for idx, val in zip(indices, values):
x = x.at[idx].add(val)

return x
x = _jnp.asarray(x)
normalized_indices, use_vectorization, len_indices = _normalize_assignment_index(
indices,
x.ndim,
axis=axis,
)
_validate_assignment_value_count(
values,
use_vectorization=use_vectorization,
len_indices=len_indices,
)
return x.at[normalized_indices].add(values)


def array_from_sparse(indices, data, target_shape):
Expand Down Expand Up @@ -521,4 +588,4 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
matrix = matrix.at[..., i, i].set(diag)
matrix = matrix.at[..., j, k].set(tri_upp)
matrix = matrix.at[..., k, j].set(tri_low)
return matrix
return matrix
59 changes: 59 additions & 0 deletions tests/backend_support/test_jax_assignment_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Regression tests for JAX backend assignment helper indexing."""

from __future__ import annotations

import importlib.util

import pytest

from tests.support.backend_runner import run_backend_code


@pytest.mark.backend_portable
def test_jax_assignment_accepts_numpy_style_advanced_indices():
if importlib.util.find_spec("jax") is None:
pytest.skip("JAX is not installed")

result = run_backend_code(
"jax",
"""
import pyrecest.backend as backend

x = backend.zeros((3, 3))
assigned = backend.assignment(x, [5.0, 7.0], [(0, 1), (1, 2)])
summed = backend.assignment_by_sum(backend.ones((3, 3)), [5.0, 7.0], [(0, 1), (1, 2)])

assert backend.to_numpy(assigned).tolist() == [[0.0, 5.0, 0.0], [0.0, 0.0, 7.0], [0.0, 0.0, 0.0]]
assert backend.to_numpy(summed).tolist() == [[1.0, 6.0, 1.0], [1.0, 1.0, 8.0], [1.0, 1.0, 1.0]]
""",
)

assert result.returncode == 0, result.stderr


@pytest.mark.backend_portable
def test_jax_assignment_accepts_list_and_boolean_indices():
if importlib.util.find_spec("jax") is None:
pytest.skip("JAX is not installed")

result = run_backend_code(
"jax",
"""
import pyrecest.backend as backend

vector = backend.zeros(3)
by_list = backend.assignment(vector, [4.0, 5.0], [0, 2])

matrix = backend.zeros((3, 3))
by_mask = backend.assignment(
matrix,
[1.0, 2.0],
[[True, False, False], [False, True, False], [False, False, False]],
)

assert backend.to_numpy(by_list).tolist() == [4.0, 0.0, 5.0]
assert backend.to_numpy(by_mask).tolist() == [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.0]]
""",
)

assert result.returncode == 0, result.stderr
Loading