
einops.einsum #73

Closed

cgarciae opened this issue Oct 20, 2020 · 26 comments

Comments

@cgarciae
Contributor

Hey! Loving einops, so much that now I feel a bit sad about standard einsum not being able to use descriptive names for dimensions. It would be amazing if einops implemented einsum with the same conveniences.

@arogozhnikov
Owner

arogozhnikov commented Oct 22, 2020

I miss that sometimes as well, but so far I've refrained from wrapping existing functionality just for the sake of this feature.

Let me keep this issue open to collect cases where people really need this reimplemented in einops.

@MilesCranmer
Contributor

+1 to this! My code right now is a mix of einops and torch.einsum, but I think it would be very nice/consistent if it was all the same syntax through einops. As the name einops suggests, I half expected einsum to be included when I tried it the first time.

I think a wrapper could start off by just being something that takes, e.g.,

einops.einsum(x, y, 'time i j, time j k -> time i k')

and simply parses it into usual einsum notation (for each corresponding backend):

torch.einsum("aij,ajk->aik", x, y)
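For illustration, a minimal sketch of what such a translation could look like (the helper name, and the absence of caching and validation, are my own and not part of einops):

import string
import torch

def naive_named_einsum(pattern, *tensors):
    # Hypothetical helper: map each multi-letter axis name to a single letter,
    # then delegate to the backend's einsum, e.g.
    # 'time i j, time j k -> time i k'  becomes  'abc,acd->abd'
    lhs, rhs = pattern.split('->')
    letters = iter(string.ascii_lowercase)
    mapping = {}

    def translate(axes):
        out = []
        for name in axes.split():
            if name not in mapping:
                mapping[name] = next(letters)
            out.append(mapping[name])
        return ''.join(out)

    native = ','.join(translate(part) for part in lhs.split(',')) + '->' + translate(rhs)
    return torch.einsum(native, *tensors)

x, y = torch.ones(2, 3, 4), torch.ones(2, 4, 5)
out = naive_named_einsum('time i j, time j k -> time i k', x, y)  # shape (2, 3, 5)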

Maybe eventually there could be cool features to consider related to the other syntax innovations in this awesome package.

Cheers!
Miles

@MilesCranmer
Contributor

One useful feature not found in existing libraries could be a combined reshape->einsum represented as a single einops expression, like so:

einops.einsum(x, y, '(i j), i j -> i')

This would first rearrange x from (i j) -> i j, using y to infer the layout of x via i=y.shape[0], and then do dot products along each row: 'ij,ij->i'.
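For comparison, a sketch of what that amounts to today as two explicit steps (the shapes here are illustrative, not from the thread):

import torch
from einops import rearrange

i, j = 3, 4
x = torch.ones(i * j)       # flat tensor laid out as (i j)
y = torch.ones(i, j)

x2 = rearrange(x, '(i j) -> i j', i=y.shape[0])   # infer i from y
result = torch.einsum('ij,ij->i', x2, y)          # dot product along each row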

@arogozhnikov
Owner

@MilesCranmer this last thing is already planned, see RFC #71

@cgarciae I think it makes sense to look at that RFC, as to my mind it covers a large fraction of the use cases for einsum in deep learning.

@cgarciae
Contributor Author

cgarciae commented Nov 3, 2020

Hey @arogozhnikov, I'd already seen the WeightedEinsum layer, it's very nice! I still think it would be useful to have a functional version of the operation.

@alok

alok commented Mar 16, 2022

I also would find einsum with more descriptive names useful. Jamming everything together into single letters, often with no whitespace, is unpleasant once there are 3+ indices to consider. I was writing out the multi-head attention from the transformer paper to practice einops and ran into this when forced to use regular einsum.

Just noted that http://einops.rocks/pytorch-examples.html says "It would be great to use just 'b c1 h w,b c2 h w->b c1 c2', but einsum supports only one-letter axes." Assuming @arogozhnikov wrote that, it seems like evidence for the feature.

@MilesCranmer
Contributor

Would you be open to a functional version @arogozhnikov? I can help add it but I want to confirm your approval first.

@arogozhnikov
Owner

@MilesCranmer
Do you have something like Tim's zweisum in mind?

https://gist.github.com/rockt/a3191f517728ea9a136a204f578d27c8

I just want to discuss some issues before they appear: parsing likely should be cached, and standard backend-guessing should be applied.

Torch's scripting doesn't play well with caching, so a layer would likely be required as well.

@MilesCranmer
Contributor

I think that syntax would be the desired style. It would also be nice to have rearrange built-in as a preprocessing step, so that one could do things like:

einops.einsum("batch (height width), height width channel -> batch channel", x, filter)

but maybe for this you would rather have the user split it into two separate operations, one rearrange and one einsum?
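For reference, a sketch of that two-operation alternative as it would look today (torch backend and concrete shapes assumed purely for illustration):

import torch
from einops import rearrange

x = torch.ones(8, 16 * 16)        # batch, (height width)
filters = torch.ones(16, 16, 3)   # height, width, channel

x2 = rearrange(x, 'batch (height width) -> batch height width', height=16)
result = torch.einsum('bhw,hwc->bc', x2, filters)   # batch, channel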

Cheers,
Miles

@arogozhnikov
Owner

There are several sides to this:

  • interface-wise, composing and decomposing on the fly is the right, einops-y way
  • speed-wise, (in some frameworks) this will incur overhead in cases where rearrangement is not required
  • compatibility: validation of input shapes without rearrangement can be deferred to downstream frameworks, which means the function can work in unusual scenarios like symbolic tracing without shapes (e.g. torch.fx, tensor variables without assigned shapes).

In total, I'd rather have a simple version without compositions/decompositions plus a corresponding layer, and leave extended support for the future.

@MilesCranmer
Contributor

PR ready for comments! #197

@MilesCranmer
Contributor

Any opinions on syntax? We are discussing the following options in the PR; each block compares the reduce and einsum syntax:

# Option 1: pattern first, like torch.einsum
y = reduce(x, "i j -> i")
y = einsum("i j -> i", x)  # same as above
y2 = einsum("i j, i j -> i", x, x)

# Option 2: pattern last, tensors first (einops style)
y = reduce(x, "i j -> i")
y = einsum(x, "i j -> i")
y2 = einsum(x, x, "i j, i j -> i")

# Option 3: reversed arrow in the pattern, tensors after the pattern
y = reduce(x, "i j -> i")
y = einsum("i <- i j", x)
y2 = einsum("i <- i j, i j", x, x)

(1) is technically the same as einsum in other packages, but differs from the einops style.

(2) is my preference, although the potential issue is that you would have a single mixed-type *args parameter and would have to unpack it inside the function into tensors and a pattern, always assuming the pattern is last.

(3) is another option, putting the tensors at the end but keeping the indices on the same side as the tensors (x is to the right of the pattern and its indices are on the right; likewise for y).

@alok

alok commented Jul 7, 2022

I like 2, but if it causes issues, 1 is fine. 3 looks odd.

@MilesCranmer
Contributor

Pinging this thread - let me know of any other opinions. I'll adapt the PR to use (2) otherwise.

@arogozhnikov
Owner

Option 2 looks better, but it can't work with type hints, and that's a strong argument against it.

@MilesCranmer
Contributor

Up to you, I am happy to implement any option.

To get 2 working with type hints, would the following be an option?

import typing
from torch import Tensor  # or the relevant backend's tensor type

@typing.overload
def einsum(tensor: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...


def einsum(*tensors_and_pattern) -> Tensor:
    ...  # actual implementation; the pattern is assumed to be the last argument

This seems to give me the correct hints.

If you are thinking about using the type hints for debugging with mypy, would the following modification give the type checker enough information?

def f(*args):
    args[:-1]: List[Tensor]
    args[-1]: str
    ...

@arogozhnikov
Owner

This will work correctly with IDEs (PyCharm / VS Code) and should pass mypy (though I did not check the latter).

There is a second part where typing matters: scripting and optimization. Even if we resolve all problems with dynamic dispatching (e.g. via the array API), PyTorch and others still won't be able to script such a function, because args would have type List[Union[str, Tensor]], and such types seem to be out of consideration.

Optimization (still an anticipated feature of CPython 3.12), as I understand it, would use similar mechanics to compile some functions based on type hints.

@arogozhnikov
Owner

arogozhnikov commented Jul 10, 2022

Looking at examples

# your example
result = einsum(batched_images, filters,
                "batch h w, h w channel -> batch channel") 

# my example
result = einsum(activations, activations,
                "b h w head c, b h2 w2 head c -> b h w h2 w2 head") 

I believe it's a good habit to write the pattern and the tensors on different lines, and in that case the position of the pattern (first or second line, i.e. first or last argument) shouldn't play a big role.

@MilesCranmer
Contributor

Good point, I didn't realize this aspect about how compilation libraries work. Will think about this more...

@MilesCranmer
Contributor

MilesCranmer commented Jul 10, 2022

A fourth option for syntax would be the following:

result = einsum([activations, activations],
                "b h w head c, b h2 w2 head c -> b h w h2 w2 head")

i.e., the first argument is a list of tensors.

This gives type stability via List[Tensor]. However, the downside is that it conflicts with existing uses of passed lists in einops, such as how rearrange([x, x], ...) implicitly performs a stack(..., dim=0), adding an extra dimension within a single pattern, rather than expecting two separate patterns. Maybe this is okay though?
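A minimal illustration of that existing list behaviour (my example, not from the thread):

import torch
from einops import rearrange

x = torch.ones(3, 4)
stacked = rearrange([x, x], 'n h w -> n h w')   # the list acts as a stack: shape (2, 3, 4)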


I tried scripting this in PyTorch, but a few other things broke: lru_cache and ParsedExpression don't seem compatible. I would assume that a library based on tracing, like torch.jit.trace or JAX, would do fine here. Is the expectation that these internal functions which are not currently compatible would be patched in the future, and that it is important for the syntax itself to be ready for scripting?

@MilesCranmer
Contributor

MilesCranmer commented Jul 10, 2022

Actually, even torch.einsum is not set up for scripting:

>>> torch.jit.script(torch.einsum)
NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:

Maybe the syntax (2) is fine as-is?

I see the PyTorch version is given as def einsum(*args: Any) -> Tensor:, which doesn't give useful type hints.

So perhaps we could use syntax (2), with typing.overload to get type hints?


Edit: It looks like you can get torch.einsum to work in a script if it is called inside the scripted function (but you can't do this with einops.reduce or einops.einsum). This doesn't seem consistent, since torch.einsum has a variable number of arguments...

Edit 2: In PyTorch, they dispatch to this internal library called "_VariableFunctions": https://github.com/pytorch/pytorch/blob/1022443168b5fad55bbd03d087abf574c9d2e9df/torch/_VF.py.

Anyway, I think with a variable number of arguments (in any part of einops), you may have to use tracing rather than scripting.
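A small sketch of the scripting behaviour described in the edits above (assumes a PyTorch version of that era; not an einops example):

import torch

# Scripting torch.einsum directly fails, but a scripted function that calls
# it with a fixed pattern compiles fine.
@torch.jit.script
def scripted_pairwise(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.einsum('ij,jk->ik', x, y)

out = scripted_pairwise(torch.ones(2, 3), torch.ones(3, 4))   # shape (2, 4)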

@MilesCranmer
Contributor

Pinging this thread. I'm fine to go ahead with einops.einsum + typing.overload strategy, if that sounds good to everyone.

@arogozhnikov
Owner

> i.e., the first argument is a list of tensors.

Lists are reserved for a semantically different thing (e.g. see indexing or 'stacking' with rearrange).

To weigh pattern-first against pattern-last, I want to go through a set of larger examples where the op is used in context. Sorry, I still didn't get to it, but I will try today.

@arogozhnikov
Owner

arogozhnikov commented Jul 14, 2022

and the winner is ... <pretends he didn't read the contents of the envelope> the pattern-last order of arguments 🎉

With pattern-last it is easier to track the flow of data (which variables define which), while reading/analyzing the pattern can be delayed until the general flow of the program is clear.

I believe this outweighs the various technical downsides. Will merge tomorrow.

A comment about torch.einsum: it is scriptable in context (i.e. if you pass arguments and a pattern, thereby specifying the operation, it is scriptable).

@NightMachinery

You could also add the other orders under different names, e.g., einsum2. I have done this for common CLI tools such as cp and mv, and I use both variants depending on which is more convenient. Of course, people can do this themselves, but it would be easier if it were done in the original library.
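For example, a user-side wrapper along those lines could be as small as this (the name einsum2 follows the suggestion above and is not an einops API):

from einops import einsum   # pattern-last, available since einops 0.5.0

def einsum2(pattern, *tensors):
    # pattern-first alias around the pattern-last einops.einsum
    return einsum(*tensors, pattern)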

@arogozhnikov
Owner

einops.einsum has been live since 0.5.0, closing.
