Skip to content

Commit

Permalink
add reduce
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Jan 13, 2022
1 parent 9c9a591 commit 26e8bd6
Show file tree
Hide file tree
Showing 6 changed files with 256 additions and 56 deletions.
10 changes: 7 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@ As stated in their website:
> xarray makes working with multi-dimensional labeled arrays simple, efficient and fun!
The code is often more verbose, but it is generally because it is clearer and thus less error prone
and intuitive. Here are some examples of such trade-off:
and more intuitive.
Here are some examples of such trade-offs, where we believe the increased clarity is worth
the extra characters:


| numpy | xarray |
|---------|----------|
| `a[2, 5]` | `da.sel(drug="paracetamol", subject=5)` |
| `a.mean(axis=(0, 1))` | `da.mean(dim=("chain", "draw"))` |
| `` | `` |
| `a.reshape((-1, 10))` | `da.stack(sample=("chain", "draw"))` |
| `a.transpose(2, 0, 1)` | `da.transpose("drug", "chain", "draw")` |

In some other cases however, using xarray can result in overly verbose code
that often also becomes less clear. `xarray-einstats` provides wrappers
Expand Down Expand Up @@ -83,7 +87,7 @@ Dimensions without coordinates: dim_plot
```

### einops
**only rearrange wrapped for now**
**repeat wrapper still missing**

[einops](https://einops.rocks/) uses a convenient notation inspired by
Einstein notation to specify operations on multidimensional arrays.
Expand Down
3 changes: 3 additions & 0 deletions docs/source/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
```

```{eval-rst}
.. currentmodule:: xarray_einstats
.. autosummary::
:toctree: generated/
Expand All @@ -25,5 +26,7 @@
rearrange
raw_rearrange
reduce
raw_reduce
```

17 changes: 16 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,25 @@

# Patterns of files and directories Sphinx ignores when looking for sources.
exclude_patterns = ["Thumbs.db", ".DS_Store", ".ipynb_checkpoints"]

# The reST default role (used for this markup: `text`) to use for all documents.
default_role = "code"

# If true, '()' will be appended to :func: etc. cross-reference text.
add_function_parentheses = False

# -- Options for extensions

# Optional MyST syntax extensions enabled for the markdown sources.
myst_enable_extensions = ["colon_fence", "deflist", "dollarmath", "amsmath"]

# Let autosummary generate the stub pages for the entries it lists.
autosummary_generate = True
# Do not render type hints in autodoc output.
autodoc_typehints = 'none'

# Have numpydoc try to cross-reference the types written in parameter descriptions.
numpydoc_xref_param_type = True
# Words inside type descriptions that numpydoc should not try to cross-reference.
numpydoc_xref_ignore = {"of", "optional"}
# Short names used in type descriptions mapped to the role they should link to.
numpydoc_xref_aliases = {
"DataArray": ":class:`xarray.DataArray`",
}


# -- Options for HTML output

Expand All @@ -55,7 +70,7 @@
# Inventories used by intersphinx to resolve cross-references to external projects.
# Every entry uses HTTPS and the current canonical documentation host so that
# Sphinx does not have to follow redirects when fetching ``objects.inv``.
intersphinx_mapping = {
    "dask": ("https://docs.dask.org/en/latest/", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "python": ("https://docs.python.org/3/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "xarray": ("https://docs.xarray.dev/en/stable/", None),
}

213 changes: 163 additions & 50 deletions src/xarray_einstats/einops.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""einops module."""
"""Wrappers for `einops <https://einops.rocks/>`_."""
import einops
import xarray as xr

__all__ = ["rearrange", "raw_rearrange"]
__all__ = ["rearrange", "raw_rearrange", "reduce", "raw_reduce"]


class DimHandler:
Expand Down Expand Up @@ -56,8 +56,50 @@ def process_pattern_list(redims, handler, allow_dict=True, allow_list=True):
return out, out_names, " ".join(txt)


def translate_pattern(pattern):
    """Translate an einops-like pattern string into a list of dimension specs.

    Parameters
    ----------
    pattern : str
        Space separated dimension names. Names grouped between parentheses
        form a block, and a block can be immediately followed by ``=name``
        to label it, e.g. ``"(chain draw)=sample"``.

    Returns
    -------
    list
        One element per top level token, in order of appearance:
        a ``str`` for a bare dimension name, a ``list`` of str for an
        unlabeled block and a single item ``dict`` ``{name: [dims]}``
        for a labeled block.

    Raises
    ------
    ValueError
        If a parenthesis is left unclosed, closed without being opened
        or opened inside another block, or if ``=`` does not follow a
        closing parenthesis.
    """
    dims = []
    current_dim = ""
    current_block = []
    in_block = False  # True between "(" and its matching ")"
    parsing_key = False  # True right after ")", while reading the optional block name
    # the extra trailing space flushes whatever token is still pending at the end
    for char in pattern.strip() + " ":
        if char == " ":
            if parsing_key:
                # a block was just closed: store it, labeled if a name followed it
                if current_dim:
                    dims.append({current_dim: current_block})
                else:
                    dims.append(current_block)
                current_block = []
                parsing_key = False
                in_block = False
            elif not current_dim:
                # consecutive spaces, nothing to flush
                continue
            elif in_block:
                current_block.append(current_dim)
            else:
                dims.append(current_dim)
            current_dim = ""
        elif char == ")":
            if not in_block:
                raise ValueError("unmatched parenthesis")
            in_block = False
            parsing_key = True
            if current_dim:
                current_block.append(current_dim)
                current_dim = ""
        elif char == "(":
            if in_block:
                raise ValueError("nested parenthesis are not supported")
            in_block = True
        elif char == "=":
            if not parsing_key:
                raise ValueError("= sign must follow a closing parenthesis )")
        else:
            current_dim += char
    if in_block:
        # without this check the names gathered in the open block would be silently dropped
        raise ValueError("unmatched parenthesis")
    return dims


def rearrange(da, out_dims, in_dims=None, **kwargs):
"""Wrap einops.rearrange.
"""Wrap `einops.rearrange <https://einops.rocks/api/rearrange/>`_.
Parameters
----------
Expand All @@ -69,7 +111,6 @@ def rearrange(da, out_dims, in_dims=None, **kwargs):
in_dims : list of str or dict, optional
The input pattern for the dimensions.
This is only necessary if you want to split some dimensions.
In einops, the left side of the pattern serves two goals
kwargs : dict, optional
kwargs with key equal to dimension names in ``out_dims``
(that is, strings or dict keys) are passed to einops.rearrange
Expand All @@ -84,7 +125,7 @@ def rearrange(da, out_dims, in_dims=None, **kwargs):
See Also
--------
xarray_einstats.raw_rearrange:
xarray_einstats.einops.raw_rearrange:
Cruder wrapper of einops.rearrange, allowed characters in dimension names are restricted
xarray.DataArray.transpose, xarray.Dataset.transpose
xarray.DataArray.stack, xarray.Dataset.stack
Expand Down Expand Up @@ -124,50 +165,8 @@ def rearrange(da, out_dims, in_dims=None, **kwargs):
)


def translate_pattern(pattern):
    """Translate an einops-like pattern string into a list of dimension specs.

    Parameters
    ----------
    pattern : str
        Space separated dimension names. Names grouped between parentheses
        form a block, and a block can be immediately followed by ``=name``
        to label it, e.g. ``"(chain draw)=sample"``.

    Returns
    -------
    list
        One element per top level token, in order of appearance:
        a ``str`` for a bare dimension name, a ``list`` of str for an
        unlabeled block and a single item ``dict`` ``{name: [dims]}``
        for a labeled block.

    Raises
    ------
    ValueError
        If a parenthesis is left unclosed, closed without being opened
        or opened inside another block, or if ``=`` does not follow a
        closing parenthesis.
    """
    dims = []
    current_dim = ""
    current_block = []
    in_block = False  # True between "(" and its matching ")"
    parsing_key = False  # True right after ")", while reading the optional block name
    # the extra trailing space flushes whatever token is still pending at the end
    for char in pattern.strip() + " ":
        if char == " ":
            if parsing_key:
                # a block was just closed: store it, labeled if a name followed it
                if current_dim:
                    dims.append({current_dim: current_block})
                else:
                    dims.append(current_block)
                current_block = []
                parsing_key = False
                in_block = False
            elif not current_dim:
                # consecutive spaces, nothing to flush
                continue
            elif in_block:
                current_block.append(current_dim)
            else:
                dims.append(current_dim)
            current_dim = ""
        elif char == ")":
            if not in_block:
                raise ValueError("unmatched parenthesis")
            in_block = False
            parsing_key = True
            if current_dim:
                current_block.append(current_dim)
                current_dim = ""
        elif char == "(":
            if in_block:
                raise ValueError("nested parenthesis are not supported")
            in_block = True
        elif char == "=":
            if not parsing_key:
                raise ValueError("= sign must follow a closing parenthesis )")
        else:
            current_dim += char
    if in_block:
        # without this check the names gathered in the open block would be silently dropped
        raise ValueError("unmatched parenthesis")
    return dims


def raw_rearrange(da, pattern, **kwargs):
"""Crudely wrap einops.rearrange.
"""Crudely wrap `einops.rearrange <https://einops.rocks/api/rearrange/>`_.
Wrapper around einops.rearrange with a very similar syntax.
Spaces, parenthesis ``()`` and `->` are not allowed in dimension names.
Expand All @@ -187,15 +186,15 @@ def raw_rearrange(da, pattern, **kwargs):
a default name.
kwargs : dict, optional
Passed to :func:`xarray_einstats.rearrange`
Passed to :func:`xarray_einstats.einops.rearrange`
Returns
-------
xarray.DataArray
See Also
--------
xarray_einstats.rearrange:
xarray_einstats.einops.rearrange:
More flexible and powerful wrapper over einops.rearrange. It is also more verbose.
"""
if "->" in pattern:
Expand All @@ -206,3 +205,117 @@ def raw_rearrange(da, pattern, **kwargs):
in_dims = None
out_dims = translate_pattern(out_pattern)
return rearrange(da, out_dims=out_dims, in_dims=in_dims, **kwargs)


def reduce(da, reduction, out_dims, in_dims=None, **kwargs):
    """Wrap `einops.reduce <https://einops.rocks/api/reduce/>`_.

    Parameters
    ----------
    da : xarray.DataArray
        Input DataArray to be reduced
    reduction : string or callable
        One of available reductions ('min', 'max', 'sum', 'mean', 'prod') by ``einops.reduce``,
        case-sensitive. Alternatively, a callable ``f(tensor, reduced_axes) -> tensor``
        can be provided. ``reduced_axes`` are passed as a list of int.
    out_dims : list of str, list or dict
        The output pattern for the dimensions.
    in_dims : list of str or dict, optional
        The input pattern for the dimensions.
        This is only necessary if you want to split some dimensions.
    kwargs : dict, optional
        kwargs whose key matches one of the dimension names processed from
        ``in_dims`` or ``out_dims`` are passed to einops.reduce,
        the rest of keys are passed to :func:`xarray.apply_ufunc`

    Returns
    -------
    object
        Whatever :func:`xarray.apply_ufunc` returns for ``da``.

    Notes
    -----
    Unlike for general xarray objects, where dimension
    names can be :term:`hashable <xarray:name>` here
    dimension names are not recommended but required to be
    strings.

    See Also
    --------
    xarray_einstats.einops.raw_reduce:
        Cruder wrapper of einops.reduce, allowed characters in dimension names are restricted
    xarray_einstats.einops.rearrange, xarray_einstats.einops.raw_rearrange
    """
    handler = DimHandler()
    # the input pattern has to go through the handler before the output one
    if in_dims is not None:
        in_dims, in_names, in_pattern = process_pattern_list(
            in_dims, handler=handler, allow_list=False
        )
    else:
        in_dims, in_names, in_pattern = [], [], ""
    out_dims, out_names, out_pattern = process_pattern_list(out_dims, handler=handler)

    # note, not using sets for the dims of `da` to avoid transpositions on missing variables,
    # if they wanted to transpose those they would not be missing variables
    untouched_dims = [dim for dim in da.dims if dim not in in_names]
    untouched_names = handler.get_names(untouched_dims)
    pattern = f"{untouched_names} {in_pattern} -> {out_pattern}"

    # split kwargs: dimension lengths go to einops, everything else to apply_ufunc
    known_dims = set(out_dims + out_names + in_names + in_dims)
    axes_lengths = {}
    ufunc_kwargs = {}
    for key, value in kwargs.items():
        if key in known_dims:
            axes_lengths[handler.rename_kwarg(key)] = value
        else:
            ufunc_kwargs[key] = value

    return xr.apply_ufunc(
        einops.reduce,
        da,
        pattern,
        reduction,
        input_core_dims=[untouched_dims + in_names, [], []],
        output_core_dims=[out_names],
        kwargs=axes_lengths,
        **ufunc_kwargs,
    )


def raw_reduce(da, pattern, reduction, **kwargs):
    """Crudely wrap `einops.reduce <https://einops.rocks/api/reduce/>`_.

    Wrapper around einops.reduce with a very similar syntax.
    Spaces, parenthesis ``()`` and `->` are not allowed in dimension names.

    Parameters
    ----------
    da : xarray.DataArray
        Input array
    pattern : string
        Pattern string. Same syntax as patterns in einops with two
        caveats:

        * Unless splitting or stacking, you must use the actual dimension names.
        * When splitting or stacking you can use `(dim1 dim2)=dim`. This is necessary
          for the left hand side as it identifies the dimension to split, and
          optional on the right hand side, if omitted the stacked dimension will be given
          a default name.

    reduction : string or callable
        One of available reductions ('min', 'max', 'sum', 'mean', 'prod') by ``einops.reduce``,
        case-sensitive. Alternatively, a callable ``f(tensor, reduced_axes) -> tensor``
        can be provided. ``reduced_axes`` are passed as a list of int.
    kwargs : dict, optional
        Passed to :func:`xarray_einstats.einops.reduce`

    Returns
    -------
    xarray.DataArray

    Raises
    ------
    ValueError
        If `pattern` contains more than one ``->`` separator.

    See Also
    --------
    xarray_einstats.einops.reduce:
        More flexible and powerful wrapper over einops.reduce. It is also more verbose.
    xarray_einstats.einops.rearrange, xarray_einstats.einops.raw_rearrange
    """
    if "->" in pattern:
        sides = pattern.split("->")
        if len(sides) != 2:
            # fail with an explicit message instead of an opaque unpacking error
            raise ValueError(
                f"pattern can contain at most one '->' separator, found {len(sides) - 1}"
            )
        in_pattern, out_pattern = sides
        in_dims = translate_pattern(in_pattern)
    else:
        # only the output side was given, there are no dimensions to split
        out_pattern = pattern
        in_dims = None
    out_dims = translate_pattern(out_pattern)
    return reduce(da, reduction, out_dims=out_dims, in_dims=in_dims, **kwargs)
Loading

0 comments on commit 26e8bd6

Please sign in to comment.