
Create einsum operation #197

Merged: 39 commits, Jul 15, 2022

Conversation

@MilesCranmer (Contributor) commented Jul 4, 2022

This adds the functional einsum function, as requested in #73. CC @arogozhnikov @cgarciae

The current implementation simply parses the string and converts it to standard einsum notation by mapping each axis name to a single character (I use string.ascii_letters, starting from a, b, c, etc.).
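For illustration, here is a minimal sketch of that mapping (a hypothetical helper, not the PR's actual code; the name to_einsum_pattern is made up):

import string

def to_einsum_pattern(pattern: str) -> str:
    # Map each named axis to a single letter, preserving commas and "->",
    # e.g. "batch h w, h w channel -> batch channel" becomes "abc,bcd->ad".
    letters = iter(string.ascii_letters)
    mapping = {}
    out = []
    for token in pattern.replace(",", " , ").replace("->", " -> ").split():
        if token in (",", "->"):
            out.append(token)
        else:
            if token not in mapping:
                mapping[token] = next(letters)
            out.append(mapping[token])
    return "".join(out)

print(to_einsum_pattern("batch h w, h w channel -> batch channel"))  # abc,bcd->ad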

Currently, it has the following features:

  • Supports the backends: tensorflow, numpy, jax, pytorch, chainer, oneflow, keras, cupy.
  • Allows for an arbitrary number of tensors passed.
  • Allows ellipsis specification, including for multiple tensors, so long as it is provided on both the left and the right of the ->.

It does not currently support:

  • Reshape operations, such as "(batch channel) feature, feature -> batch channel".
  • Custom reduction operations.

These could be added later if desired. Some backends do not support custom reductions in their einsum implementations, so that would be a bit more work.

I also added a docstring and some unittests (in tests/test_einsum.py).

Here are some examples of use, with the numpy backend:

>>> import numpy as np
>>> from einops import einsum

# Filter a set of images:
>>> batched_images = np.random.randn(128, 16, 16)
>>> filters = np.random.randn(16, 16, 30)
>>> result = einsum(batched_images, filters,
...                 "batch h w, h w channel -> batch channel")

>>> result.shape
(128, 30)

# Matrix multiplication, with an unknown input shape:
>>> batch_shape = (50, 30)
>>> data = np.random.randn(*batch_shape, 20)
>>> weights = np.random.randn(10, 20)
>>> result = einsum(weights, data, 
...                 "out_dim in_dim, ... in_dim -> ... out_dim")
>>> result.shape
(50, 30, 10)

Note that the number of spaces around the comma above is arbitrary: you could write either "in_dim, ..." or "in_dim , ...", and both will work.

Eager to hear feedback on this!

Cheers,
Miles


Edit 1: Got it working for repeated indices on one side (as used in, e.g., trace).
Edit 2: Added support for chainer, oneflow, cupy, tensorflow.keras.
Edit 3: Added many more tests, some mirroring those used in the np.einsum tests.
Edit 4: More and more unit tests.
Edit 5: Tweaked the syntax to have tensors first, pattern second. Adapted tests, and added new validation for order of arguments.

@MilesCranmer mentioned this pull request on Jul 4, 2022
@MilesCranmer (Contributor, Author) commented Jul 4, 2022

I think implementing the rearrange operations inside einsum shouldn't be too bad: since einsum doesn't require any specific order of indices, you could call

for (tensor, left_expression) in zip(tensors, left_expressions):
    axis_names = ...  # flat list of the axis names appearing in left_expression
    tensor = rearrange(tensor, left_expression + "->" + " ".join(axis_names))

for every left_expression. Then, you would pass " ".join(axis_names) back to einops.einsum.

Then, you could do a similar rearrange on the output expression.
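For example, a composite output axis like "i j, j k -> (i k)" could be handled by a plain einsum followed by a rearrange of the result. A rough sketch, assuming the einsum from this PR together with the existing rearrange (not code from the PR):

import numpy as np
from einops import rearrange, einsum

x = np.random.randn(3, 4)
y = np.random.randn(4, 5)

# Desired (not yet supported) pattern: "i j, j k -> (i k)"
out = einsum(x, y, "i j, j k -> i k")   # regular einsum without grouping
out = rearrange(out, "i k -> (i k)")    # fold the composite output axis afterwards
print(out.shape)  # (15,)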

What do you think @arogozhnikov? (for a future PR, of course)

@MilesCranmer (Contributor, Author) commented Jul 4, 2022

Actually, I realized this doesn't work with repeated variable names, e.g., if I want to compute the trace of a tensor:

einsum("index index -> ", np.ones((5, 5)))

ParsedExpression doesn't allow duplicate dimensions, so this doesn't work. I guess I could modify it to allow this.

Edit: fixed with allow_duplicates argument to ParsedExpression.
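With that fix, something like the following should work for the trace (a quick sketch using the final tensors-first syntax and the numpy backend, not a snippet from the PR):

import numpy as np
from einops import einsum

# Repeated index on the left, empty output: the trace of a square matrix.
trace = einsum(np.ones((5, 5)), "index index ->")
print(trace)  # 5.0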

@arogozhnikov (Owner) left a comment

Cool, thanks Miles for the very clean PR and for collecting a test suite.

I've left some thoughts and requests for two kinds of tests:

  1. fail on parsing of features that aren't yet supported
  2. add tests for symbolic backends (probably the latter would only include tf, but still)

Review threads (resolved): einops/einops.py, tests/test_einsum.py
@MilesCranmer (Contributor, Author)

Okay, all suggestions implemented. Let me know what you think.

@MilesCranmer (Contributor, Author)

Okay everything is implemented for the new syntax:

y = einsum(x, x, "i j, i k -> j k")

I also added new validation checks for the argument order, and corresponding unit tests.
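As a hedged illustration of that validation (not a snippet from the PR; the exact error type and message may differ), passing the pattern before the tensors should now be rejected:

import numpy as np
from einops import einsum

x = np.random.randn(2, 3)
try:
    einsum("i j, i k -> j k", x, x)   # wrong order: pattern first
except Exception as err:
    print(type(err).__name__)         # the argument-order validation fires here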

In the unit tests, I now check the specific message of each error rather than just the error type.

Let me know what you think.

@arogozhnikov (Owner) commented Jul 10, 2022

> for every left_expression. Then, you would pass " ".join(axis_names) back to einops.einsum.
> Then, you could do a similar rearrange on the output expression.

It's trickier, since you want some of the axes to be derived from the input shapes.
In the example you previously posted it could be something like (i j) k, i -> j k, so the shape of the second argument would have to be parsed first. For the output that's actually straightforward, and just applying a pattern would always work.
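A hypothetical sketch of that inference (not PR code): for "(i j) k, i -> j k", the size of i has to be read from the second tensor before the first one can be un-flattened.

import numpy as np
from einops import rearrange, einsum

x = np.random.randn(6, 4)   # axes "(i j) k", with i * j == 6
v = np.random.randn(2)      # axes "i", so i == 2 and hence j == 3

i = v.shape[0]                                     # derive i from the second input's shape
x_unflat = rearrange(x, "(i j) k -> i j k", i=i)   # un-flatten the composite axis
result = einsum(x_unflat, v, "i j k, i -> j k")
print(result.shape)  # (3, 4)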

Not relevant for the PR; just commenting since you asked.

@@ -560,6 +578,12 @@ def layers(self):
        from .layers import keras
        return keras

    def einsum(self, pattern, *x):
        return self.tf.vectorized_map(
@arogozhnikov (Owner):

I want to understand why it looks so strange in tf.keras.

@MilesCranmer (Contributor, Author):

I'm not sure if I was interpreting the symbolic (layer=True) backends correctly or not.

Basically, this einsum assumes the x tensors have a leading batch axis, which is assumed not to be specified in the pattern. I assumed this because the create_symbol method specifies the shape as a batch shape rather than an absolute shape. Is that correct, or should it assume the pattern also specifies the batch axis?

@MilesCranmer (Contributor, Author):

Actually, I think my implementation has a potential issue if one symbol is batched and another is not (like a weight matrix).

@MilesCranmer (Contributor, Author):

What do you think the correct strategy is here? Should I avoid adding einsum for keras, since it is technically a layer=True backend?

@arogozhnikov (Owner):

layer=True just refers to providing layers; it should not be related to any batch variables, and patterns should include batch variables. Anyway, I forgot that keras is now just a redirection to the TF layers, so I just excluded this part.

@arogozhnikov merged commit e168125 into arogozhnikov:master on Jul 15, 2022
@arogozhnikov (Owner)

PR is merged; I made very minor changes.
Thank you for paying attention to details and for keeping pushing this!

@MilesCranmer (Contributor, Author)

Awesome!! Great to hear.

@gerdm commented Oct 4, 2022

This is a great PR!
+1 for “Custom reduction operations”. Is there anyone already working on this?

@alok commented May 9, 2023

I'm interested in adding rearrange support. Any pointers?
