Create einsum operation #197

Merged: 39 commits, Jul 15, 2022

Changes from 31 commits

Commits:
95f670f  Add einsum for np, torch, tf, jax backends (MilesCranmer, Jul 4, 2022)
2cfa0ef  Add function that converts to einsum pattern (MilesCranmer, Jul 4, 2022)
f499573  Add function that calls einsum for some tensors (MilesCranmer, Jul 4, 2022)
ce91d48  Include missing string import for einsum (MilesCranmer, Jul 4, 2022)
05fa3c4  Include einsum in main library (MilesCranmer, Jul 4, 2022)
6cd3b1b  Refactor test_einsum (MilesCranmer, Jul 4, 2022)
b167619  Add tests for functional einsum (MilesCranmer, Jul 4, 2022)
4589c6d  Create docstring for einsum (MilesCranmer, Jul 4, 2022)
7154b5a  Clean up doctest for einsum (MilesCranmer, Jul 4, 2022)
5672ea8  Allow duplicates in ParsedExpression for einsum (MilesCranmer, Jul 4, 2022)
c8173b0  Add tests for repeated indices in einsum (MilesCranmer, Jul 4, 2022)
31ae88b  Clean up einops.einsum docstrings (MilesCranmer, Jul 4, 2022)
29c8392  Improve readability of einsum implementation (MilesCranmer, Jul 4, 2022)
decdfbe  Implement einsum for additional backends (MilesCranmer, Jul 4, 2022)
7f4285c  Include einsum tests for additional backends (MilesCranmer, Jul 4, 2022)
f94234a  Expand einsum documentation (MilesCranmer, Jul 4, 2022)
fc93d3b  Fix einsum docstring example (MilesCranmer, Jul 4, 2022)
1060621  Test values passed through einsum (MilesCranmer, Jul 4, 2022)
5a7805f  Include helpful error message if -> missing (MilesCranmer, Jul 4, 2022)
b0d8c45  Remove unused class in einsum test (MilesCranmer, Jul 4, 2022)
d01ae74  Include many more einsum tests (MilesCranmer, Jul 4, 2022)
21ec6bd  Remove unnecessary strip() in einsum (MilesCranmer, Jul 4, 2022)
a8ec295  Error for singleton axes (MilesCranmer, Jul 5, 2022)
7590a77  Include symbolic backends in einsum test (MilesCranmer, Jul 5, 2022)
27e8ebd  Attempt inclusion of mxnet backend for einsum (MilesCranmer, Jul 5, 2022)
102e732  Remove mxnet einsum as non-functional (MilesCranmer, Jul 5, 2022)
fc076fa  Reduce indentation of einsum tests (MilesCranmer, Jul 5, 2022)
a12071d  Clean up axis name validation in einsum (MilesCranmer, Jul 5, 2022)
555484e  Fix error checking for einsum (MilesCranmer, Jul 5, 2022)
082aec0  Fix einsum for keras implementation (MilesCranmer, Jul 5, 2022)
eaa57c8  Add symbolic test for einsum in keras (MilesCranmer, Jul 5, 2022)
33ecafa  Clean up einsum tests (MilesCranmer, Jul 5, 2022)
7abee99  Test error paths of einsum pattern creation (MilesCranmer, Jul 5, 2022)
52da9e5  Change einsum syntax to `*tensors, pattern` (MilesCranmer, Jul 9, 2022)
43a8ef3  Test for bad order of tensors (MilesCranmer, Jul 9, 2022)
291a777  Validate actual error messages (MilesCranmer, Jul 9, 2022)
ed0038f  Remove unused error; add tests (MilesCranmer, Jul 9, 2022)
ceb33ab  Remove unused testing functionality (MilesCranmer, Jul 9, 2022)
5ab5c0d  Update ordering of params in docstring (MilesCranmer, Jul 10, 2022)
5 changes: 3 additions & 2 deletions einops/__init__.py
@@ -7,6 +7,7 @@ class EinopsError(RuntimeError):
     pass
 
 
-__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError']
+__all__ = ['rearrange', 'reduce', 'repeat', 'einsum',
+           'parse_shape', 'asnumpy', 'EinopsError']
 
-from .einops import rearrange, reduce, repeat, parse_shape, asnumpy
+from .einops import rearrange, reduce, repeat, einsum, parse_shape, asnumpy
35 changes: 31 additions & 4 deletions einops/_backends.py
@@ -12,6 +12,7 @@
 
 import sys
 import warnings
+import functools
 
 __author__ = 'Alex Rogozhnikov'
 
@@ -117,6 +118,9 @@ def layers(self):
     def __repr__(self):
         return "<einops backend for {}>".format(self.framework_name)
 
+    def einsum(self, pattern, *x):
+        raise NotImplementedError("backend does not support einsum")
+
 
 class UnknownSize:
     """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """
@@ -168,6 +172,9 @@ def is_float_type(self, x):
     def add_axis(self, x, new_position):
         return self.np.expand_dims(x, new_position)
 
+    def einsum(self, pattern, *x):
+        return self.np.einsum(pattern, *x)
+
 
 class JaxBackend(NumpyBackend):
     framework_name = 'jax'
@@ -353,6 +360,9 @@ def layers(self):
     def layers(self):
         from .layers import torch
         return torch
 
+    def einsum(self, pattern, *x):
+        return self.torch.einsum(pattern, *x)
+
 
 class CupyBackend(AbstractBackend):
     framework_name = 'cupy'
@@ -385,6 +395,9 @@ def add_axis(self, x, new_position):
     def is_float_type(self, x):
         return x.dtype in ('float16', 'float32', 'float64', 'float128')
 
+    def einsum(self, pattern, *x):
+        return self.cupy.einsum(pattern, *x)
+
 
 class ChainerBackend(AbstractBackend):
     framework_name = 'chainer'
@@ -428,6 +441,9 @@ def layers(self):
     def layers(self):
         from .layers import chainer
         return chainer
 
+    def einsum(self, pattern, *x):
+        return self.chainer.functions.einsum(pattern, *x)
+
 
 class HashableTuple:
     """Overcomes non-hashability of symbolic elements"""
@@ -507,6 +523,9 @@ def layers(self):
     def layers(self):
         from .layers import tensorflow
         return tensorflow
 
+    def einsum(self, pattern, *x):
+        return self.tf.einsum(pattern, *x)
+
 
 class KerasBackend(AbstractBackend):
     framework_name = 'tensorflow.keras'
@@ -524,9 +543,8 @@ def create_symbol(self, shape):
         return self.keras.Input(batch_shape=shape)
 
     def eval_symbol(self, symbol, input_dict):
-        (variable, value), = input_dict
-        model = self.keras.models.Model(variable, symbol)
-        return model.predict_on_batch(value)
+        model = self.keras.models.Model([var for (var, _) in input_dict], symbol)
+        return model.predict_on_batch([val for (_, val) in input_dict])
 
     def arange(self, start, stop):
         return self.K.arange(start, stop)
@@ -560,6 +578,12 @@ def layers(self):
     def layers(self):
         from .layers import keras
         return keras
 
+    def einsum(self, pattern, *x):
+        return self.tf.vectorized_map(
+            functools.partial(self.tf.einsum, pattern),
+            *x
+        )
+

Review thread on KerasBackend.einsum:

Owner: want to understand why it looks so strange in tf.keras

Contributor (author): I'm not sure if I was interpreting the symbolic (layer=True) backends correctly or not. Basically, this einsum assumes the x tensors have a leading batch axis, which is assumed not to be specified in the pattern. I assumed this because the create_symbol method specifies the shape as a batch shape, rather than an absolute shape. Is that correct, or should it assume the pattern also specifies the batch axis?

Contributor (author): Actually, I think my implementation has a potential issue: if one symbol is batched, and one symbol is not (like a weight matrix).

Contributor (author): What do you think the correct strategy is here? Should I avoid adding einsum for keras, since it is technically a layer=True backend?

Owner: layer=True just refers to providing layers; it should not be related to any batch variables, and patterns should include batch variables. Anyway, I forgot keras is now just a redirection to TF layers, so I just excluded this part.


 class OneFlowBackend(AbstractBackend):
     framework_name = "oneflow"
@@ -621,4 +645,7 @@ def is_float_type(self, x):
 
     def layers(self):
         from .layers import oneflow
-        return oneflow
+        return oneflow
+
+    def einsum(self, pattern, *x):
+        return self.flow.einsum(pattern, *x)
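
To make the batching concern from the thread above concrete, here is a minimal sketch (not part of the PR; the pattern, names, and shapes are illustrative) of why mapping `tf.einsum` through `tf.vectorized_map` assumes every operand carries a leading batch axis:

```python
import tensorflow as tf

pattern = "ij,jk->ik"  # per-example matmul; the batch axis is implicit

# Both operands batched: vectorized_map slices each input along axis 0
# and applies einsum to every (x_i, y_i) pair.
x = tf.random.normal([8, 4, 5])
y = tf.random.normal([8, 5, 3])
out = tf.vectorized_map(lambda xy: tf.einsum(pattern, *xy), (x, y))
print(out.shape)  # (8, 4, 3)

# One operand unbatched (e.g. a shared weight matrix): the same strategy
# breaks, because vectorized_map would slice the weight matrix too.
w = tf.random.normal([5, 3])
# tf.vectorized_map(lambda xw: tf.einsum(pattern, *xw), (x, w))  # wrong
# The batch axis has to appear in the pattern instead:
out2 = tf.einsum("bij,jk->bik", x, w)
print(out2.shape)  # (8, 4, 3)
```

This is exactly the mismatch the owner resolved by excluding the separate keras implementation.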
142 changes: 142 additions & 0 deletions einops/einops.py
@@ -1,5 +1,6 @@
 import functools
 import itertools
+import string
 import typing
 from collections import OrderedDict
 from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
@@ -624,3 +625,144 @@ def asnumpy(tensor) -> 'numpy.ndarray':
         `numpy.ndarray`, converted to numpy
     """
     return get_backend(tensor).to_numpy(tensor)

+
+def _validate_einsum_axis_name(axis_name):
+    if len(axis_name) == 0:
+        raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
+    if len(axis_name) > 1:
+        raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
+
+    axis_name = axis_name[0]
+
+    if len(axis_name) == 0:
+        raise RuntimeError("Encountered empty axis name in einsum.")
+    if not isinstance(axis_name, str):
+        raise RuntimeError("Axis name in einsum must be a string.")
+    if axis_name[0].isdigit():
+        raise RuntimeError("Axis name in einsum must not start with a number.")
+
+
+@functools.lru_cache(256)
+def _compactify_pattern_for_einsum(pattern: str) -> str:
+    if "->" not in pattern:
+        # numpy allows this, so make sure users
+        # don't accidentally do something like this.
+        raise ValueError("Einsum pattern must contain '->'.")
+    lefts, right = pattern.split('->')
+    lefts = lefts.split(',')
+
+    lefts = [
+        ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
+        for left in lefts
+    ]
+
+    right = ParsedExpression(right, allow_underscore=True)
+
+    # Start from a, and go up to Z
+    output_axis_names = string.ascii_letters
+    i = 0
+    axis_name_mapping = {}
+
+    left_patterns = []
+    for left in lefts:
+        left_pattern = ""
+        for raw_axis_name in left.composition:
+
+            if raw_axis_name == _ellipsis:
+                left_pattern += '...'
+                continue
+
+            _validate_einsum_axis_name(raw_axis_name)
+            axis_name = raw_axis_name[0]
+            if axis_name not in axis_name_mapping:
+                if i >= len(output_axis_names):
+                    raise RuntimeError("Too many axes in einsum.")
+                axis_name_mapping[axis_name] = output_axis_names[i]
+                i += 1
+
+            left_pattern += axis_name_mapping[axis_name]
+        left_patterns.append(left_pattern)
+
+    output_pattern = ",".join(left_patterns) + "->"
+
+    for raw_axis_name in right.composition:
+        if raw_axis_name == _ellipsis:
+            output_pattern += '...'
+            continue
+
+        _validate_einsum_axis_name(raw_axis_name)
+        axis_name = raw_axis_name[0]
+
+        if axis_name not in axis_name_mapping:
+            raise RuntimeError("Unknown axis on right side of einsum.")
+
+        output_pattern += axis_name_mapping[axis_name]
+
+    return output_pattern
+
+
+def einsum(pattern: str, *tensors: List[Tensor]) -> Tensor:
+    """
+    einops.einsum calls einsum operations with einops-style named
+    axes indexing, computing tensor products with an arbitrary
+    number of tensors.
+
+    Note that unlike other einops functions, here you must give
+    the pattern before the tensor(s), rather than after.
+    Also, note that rearrange operations such as `"(batch chan) out"`,
+    or singleton axes `()`, are not currently supported.
+
+    Examples:
+
+    For a given pattern such as:
+    ```python
+    >>> x, y, z = np.random.randn(3, 20, 20, 20)
+    >>> output = einsum("a b c, c b d, a g k -> a b k", x, y, z)
+
+    ```
+    the following formula is computed:
+    ```tex
+    output[a, b, k] =
+        \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
+    ```
+    where the summation over `c`, `d`, and `g` is performed
+    because those axis names do not appear on the right-hand side.
+
+    Let's see some additional examples:
+    ```python
+    # Filter a set of images:
+    >>> batched_images = np.random.randn(128, 16, 16)
+    >>> filters = np.random.randn(16, 16, 30)
+    >>> result = einsum("batch h w, h w channel -> batch channel",
+    ...                 batched_images, filters)
+    >>> result.shape
+    (128, 30)
+
+    # Matrix multiplication, with an unknown input shape:
+    >>> batch_shape = (50, 30)
+    >>> data = np.random.randn(*batch_shape, 20)
+    >>> weights = np.random.randn(10, 20)
+    >>> result = einsum("out_dim in_dim, ... in_dim -> ... out_dim",
+    ...                 weights, data)
+    >>> result.shape
+    (50, 30, 10)
+
+    # Matrix trace on a single tensor:
+    >>> matrix = np.random.randn(10, 10)
+    >>> result = einsum("i i ->", matrix)
+    >>> result.shape
+    ()
+
+    ```
+
+    Parameters:
+        pattern: string, einsum pattern, with commas
+            separating specifications for each tensor.
+        tensors: tensors of any supported library (numpy, tensorflow, pytorch, jax).
+
+    Returns:
+        Tensor of the same type as input, after processing with einsum.
+    """
+    pattern = _compactify_pattern_for_einsum(pattern)
+    return get_backend(tensors[0]).einsum(pattern, *tensors)
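
As a concrete illustration of the translation step above, here is a sketch that imports the private helper `_compactify_pattern_for_einsum` and uses the pattern-first signature shown in this diff (commit 52da9e5 later reorders it to `*tensors, pattern`); the expected outputs in the comments follow from the mapping logic above:

```python
import numpy as np
from einops.einops import _compactify_pattern_for_einsum, einsum

# Each distinct axis name maps to one letter in order of first
# appearance; ellipses pass through; repeated names are kept:
print(_compactify_pattern_for_einsum(
    "batch h w, h w channel -> batch channel"))    # abc,bcd->ad
print(_compactify_pattern_for_einsum(
    "out_dim in_dim, ... in_dim -> ... out_dim"))  # ab,...b->...a
print(_compactify_pattern_for_einsum("i i ->"))    # aa->

# The compacted pattern is then dispatched to the backend's einsum:
images = np.random.randn(128, 16, 16)
filters = np.random.randn(16, 16, 30)
out = einsum("batch h w, h w channel -> batch channel", images, filters)
assert out.shape == (128, 30)
```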
5 changes: 3 additions & 2 deletions einops/parsing.py
@@ -26,7 +26,8 @@ class ParsedExpression:
     non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
     and keeps some information important for downstream
     """
-    def __init__(self, expression, *, allow_underscore: bool = False):
+    def __init__(self, expression, *, allow_underscore: bool = False,
+                 allow_duplicates: bool = False):
         self.has_ellipsis: bool = False
         self.has_ellipsis_parenthesized: Optional[bool] = None
         self.identifiers: Set[str] = set()
@@ -48,7 +49,7 @@ def __init__(self, expression, *, allow_underscore: bool = False):
         def add_axis_name(x):
             if x is not None:
                 if x in self.identifiers:
-                    if not (allow_underscore and x == "_"):
+                    if not (allow_underscore and x == "_") and not allow_duplicates:
                         raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
                 if x == _ellipsis:
                     self.identifiers.add(_ellipsis)
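
A short sketch of what the new `allow_duplicates` flag changes (assuming only the `ParsedExpression` API shown above):

```python
from einops import EinopsError
from einops.parsing import ParsedExpression

# Default behavior: a repeated axis name is rejected.
try:
    ParsedExpression("i i")
except EinopsError as err:
    print(err)  # Indexing expression contains duplicate dimension "i"

# With allow_duplicates=True, trace-like einsum patterns ("i i ->")
# can be parsed; each occurrence stays in the composition.
expr = ParsedExpression("i i", allow_duplicates=True)
print(expr.composition)  # [['i'], ['i']]
```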