
array-api-shape-check



Documentation: https://array-api-shape-check.readthedocs.io

Source Code: https://github.com/34j/array-api-shape-check


Check shapes of input arrays easily.

Installation

Install this via pip (or your favourite package manager):

pip install array-api-shape-check

Usage

>>> from array_api_shape_check import check_shapes
>>> info = check_shapes("ij,*k*l,*li", (1, 4), (5, 6, 7), (1, 7, 3))
>>> info.all
((i:1->3, j:4), (*k:(5,), *l:(6, 7)), (*l:(1, 7)->(6, 7), i:3))
>>> info.unique
{'i': i:3, 'j': j:4, 'k': *k:(5,), 'l': *l:(6, 7)}

Internally, check_shapes() calls parse_variable_ndim(), which determines the number of dimensions of each variable subscript by least squares. If this succeeds, it checks whether each subscript is consistent across the arrays and finally raises a single error that reports all inconsistencies at once.
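The README does not spell out the equations behind the least-squares step. A minimal sketch, assuming one linear equation per array (the number of occurrences of each variable subscript times its unknown ndim equals the array's ndim minus its number of fixed subscripts), could look like this. `parse_subscripts` and `solve_variable_ndims` are hypothetical helpers written for illustration, not part of array-api-shape-check's API, and only single-letter subscripts are handled.

```python
import re

import numpy as np


def parse_subscripts(spec: str) -> list[list[tuple[str, bool]]]:
    """Hypothetical helper: "ij,*k" -> [[("i", False), ("j", False)], [("k", True)]]."""
    return [
        [(name, star == "*") for star, name in re.findall(r"(\*?)(\w)", part)]
        for part in spec.split(",")
    ]


def solve_variable_ndims(spec: str, *shapes: tuple[int, ...]) -> dict[str, int]:
    """Hypothetical sketch: find the ndim of every variable subscript by least squares."""
    parsed = parse_subscripts(spec)
    if len(parsed) != len(shapes):
        raise ValueError("one shape per comma-separated subscript group is required")
    variables = sorted({name for part in parsed for name, var in part if var})
    if not variables:
        return {}  # nothing to solve; fixed-only specs are checked elsewhere
    index = {name: k for k, name in enumerate(variables)}

    # One row per array: sum(count * ndim of variable) = ndim - number of fixed subscripts
    A = np.zeros((len(parsed), len(variables)))
    b = np.zeros(len(parsed))
    for row, (part, shape) in enumerate(zip(parsed, shapes)):
        for name, var in part:
            if var:
                A[row, index[name]] += 1
        b[row] = len(shape) - sum(not var for _, var in part)

    x, _, rank, _ = np.linalg.lstsq(A, b, rcond=None)
    if rank < len(variables):
        raise ValueError("multiple solutions: the system is rank deficient")
    if not np.allclose(A @ x, b):
        raise ValueError("no solution: nonzero residuals")
    ndims = np.rint(x).astype(int)
    if not np.allclose(x, ndims):
        raise ValueError("no solution: non-integer ndim")
    if np.any(ndims < 0):
        raise ValueError("no solution: negative ndim")
    return dict(zip(variables, ndims.tolist()))
```

For the Usage example, `solve_variable_ndims("ij,*k*l,*li", (1, 4), (5, 6, 7), (1, 7, 3))` returns `{'k': 1, 'l': 2}`, which agrees with the `*k:(5,)` and `*l:(6, 7)` output above.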

Diving into the details of the first item:

>>> item = info.all[0][0]
>>> item.name # the name of the subscript
'i'
>>> item.is_variable # whether the subscript is variable (starts with "*")
False
>>> item.shape_current # the current shape of the subscript
(1,)
>>> item.shape_broadcasted # the broadcasted shape of the subscript
(3,)
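The two attributes suggest that `shape_broadcasted` is the broadcast of every occurrence of a subscript. Assuming this follows NumPy's broadcasting rules (an assumption, not confirmed by the README), the `unique` values from the Usage example can be reproduced with `numpy.broadcast_shapes` (NumPy 1.20+). The `occurrences` mapping below is written out by hand for illustration.

```python
import numpy as np

# Shapes bound to each subscript in the Usage example, one entry per occurrence
# (hand-written for illustration; the library collects these itself)
occurrences = {
    "i": [(1,), (3,)],
    "j": [(4,)],
    "k": [(5,)],
    "l": [(6, 7), (1, 7)],
}

# Broadcast all occurrences of the same subscript together
broadcasted = {name: np.broadcast_shapes(*shapes) for name, shapes in occurrences.items()}
print(broadcasted)  # {'i': (3,), 'j': (4,), 'k': (5,), 'l': (6, 7)}
```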

Not enough information to determine variable subscript ndims:

>>> import pytest
>>> from array_api_shape_check import InconsistentNdimErrorMultipleSolutions, InconsistentNdimErrorNoSolutions, InconsistentShapeError
>>> with pytest.raises(InconsistentNdimErrorMultipleSolutions, match="number of variables"):
...     check_shapes("*i*j", (1, 1))
>>> with pytest.raises(InconsistentNdimErrorMultipleSolutions, match="rank"):
...     check_shapes("*i*j,*i*j", (1, 1), (1, 1))
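Writing $n_i$ and $n_j$ for the unknown ndims (notation introduced here), and assuming one equation per array as described above, both calls leave the system underdetermined: for example $(n_i, n_j) = (0, 2)$, $(1, 1)$ and $(2, 0)$ all fit. This lines up with the `match=` strings in the two tests.

```latex
% "*i*j" with shape (1, 1): one equation, two unknowns
n_i + n_j = 2

% "*i*j,*i*j" with shapes (1, 1), (1, 1): two identical equations
\begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix}
\begin{pmatrix} n_i \\ n_j \end{pmatrix}
=
\begin{pmatrix} 2 \\ 2 \end{pmatrix},
\qquad
\operatorname{rank}\begin{pmatrix} 1 & 1 \\ 1 & 1 \end{pmatrix} = 1 < 2
```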

No valid solution exists for the variable subscript ndims:

>>> with pytest.raises(InconsistentNdimErrorNoSolutions, match="residuals"):
...     check_shapes("*i,*i", (1, 1), (1, 1, 1))
>>> with pytest.raises(InconsistentNdimErrorNoSolutions, match="negative"):
...     check_shapes("*ij", ())
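Under the same assumed one-equation-per-array model, with $n_i$ as the unknown ndim of `*i` (notation introduced here), the first call gives contradictory equations whose least-squares fit leaves nonzero residuals, and the second has a fixed subscript `j` that a 0-d array cannot hold, forcing a negative ndim:

```latex
% "*i,*i" with shapes (1, 1), (1, 1, 1): contradictory equations
n_i = 2, \qquad n_i = 3
\;\Rightarrow\; \hat{n}_i = \tfrac{5}{2}, \qquad
(2 - \hat{n}_i,\ 3 - \hat{n}_i) = \left(-\tfrac{1}{2},\ \tfrac{1}{2}\right) \neq 0

% "*ij" with shape (): one fixed subscript j in a 0-d array
n_i + 1 = 0 \;\Rightarrow\; n_i = -1 < 0
```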

Shapes do not match:

>>> with pytest.raises(InconsistentShapeError):
...     check_shapes("ij,*k*l,*li", (3, 4), (5, 6), (1, 7, 3))
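Here the ndims can be resolved (`*k` has 0 dims and `*l` has 2), but the shapes bound to `l` disagree: (5, 6) in the second array and (1, 7) in the third. Assuming consistency follows NumPy broadcasting rules (an assumption suggested by the `i:1->3` output above), a hypothetical `broadcast_compatible` helper shows which subscript fails:

```python
import numpy as np


def broadcast_compatible(*shapes: tuple[int, ...]) -> bool:
    """Hypothetical helper: True if the shapes broadcast together under NumPy rules."""
    try:
        np.broadcast_shapes(*shapes)
    except ValueError:
        return False
    return True


# *l is (5, 6) in the second array but (1, 7) in the third; 6 and 7 cannot broadcast
print(broadcast_compatible((5, 6), (1, 7)))  # False
# i is 3 in both the first and third arrays, which is consistent
print(broadcast_compatible((3,), (3,)))  # True
```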

Contributors ✨

Thanks go to these wonderful people.

This project follows the all-contributors specification. Contributions of any kind welcome!

Credits


This package was created with Copier and the browniebroke/pypackage-template project template.
