Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add einsum #7526

Merged
merged 33 commits into from
Mar 6, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
043e47e
add einsum implementation
Ldpe2G Feb 17, 2022
a631682
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Feb 17, 2022
6c2c661
reformat
Ldpe2G Feb 17, 2022
a3ca9a8
fix ci docs build error
Ldpe2G Feb 18, 2022
b85d308
add more test cases from DALLE-pytorch and alphafold repos
Ldpe2G Feb 18, 2022
73c4e8a
fix docs format
Ldpe2G Feb 18, 2022
8dc2ad6
add detailed explaination for better understanding how einsum works
Ldpe2G Feb 20, 2022
375d34d
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Feb 20, 2022
a1396ef
fix error
Ldpe2G Feb 21, 2022
2376b04
refine
Ldpe2G Feb 21, 2022
4c18c83
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Feb 22, 2022
b2a776b
add eager global tests
Ldpe2G Feb 22, 2022
2a54022
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Feb 27, 2022
a13bc09
merge master
Ldpe2G Feb 27, 2022
4322e4e
try to fix ci 2n4d error
Ldpe2G Feb 28, 2022
df6241c
move test_einsum.py to expensive folder
Ldpe2G Mar 2, 2022
bfc4021
Merge branch 'master' of https://github.com/Oneflow-Inc/oneflow into …
Ldpe2G Mar 2, 2022
e29a32b
add more test cases
Ldpe2G Mar 3, 2022
47c0a8b
Merge branch 'master' into dev_add_einsum
Ldpe2G Mar 3, 2022
a9f6841
add testcases from openflod repo
Ldpe2G Mar 4, 2022
a876999
Merge branch 'master' into dev_add_einsum
Ldpe2G Mar 4, 2022
8f1de39
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 4, 2022
2b5cf8c
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 4, 2022
45b1e98
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 4, 2022
75e2562
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 4, 2022
1ea2a04
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 4, 2022
a26e828
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 5, 2022
bbe3ea7
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 5, 2022
83aefcb
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 5, 2022
0bf0b16
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 5, 2022
06091b7
Merge branch 'master' into dev_add_einsum
oneflow-ci-bot Mar 5, 2022
a73d30c
Merge branch 'master' into dev_add_einsum
Ldpe2G Mar 6, 2022
3a44f35
Merge branch 'master' into dev_add_einsum
mergify[bot] Mar 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/oneflow.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ oneflow
as_strided,
div,
dot,
eq,
eq,
einsum,
equal,
expand,
eye,
Expand Down
3 changes: 2 additions & 1 deletion oneflow/core/common/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ inline uint32_t NewRandomSeed() {
#define DIM_SEQ \
OF_PP_MAKE_TUPLE_SEQ(1) \
OF_PP_MAKE_TUPLE_SEQ(2) \
OF_PP_MAKE_TUPLE_SEQ(3) OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6)
OF_PP_MAKE_TUPLE_SEQ(3) \
OF_PP_MAKE_TUPLE_SEQ(4) OF_PP_MAKE_TUPLE_SEQ(5) OF_PP_MAKE_TUPLE_SEQ(6) OF_PP_MAKE_TUPLE_SEQ(7)

#define BOOL_SEQ (true)(false)

Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2008,3 +2008,7 @@
- name: "cumprod_grad"
signature: "Tensor (Tensor input, Tensor y, Tensor x, Int64 dim) => CumprodGrad"
bind_python: False

- name: "einsum"
signature: "Tensor (String equation, TensorTuple operands) => EinSum"
bind_python: True
574 changes: 574 additions & 0 deletions oneflow/core/functional/impl/math_functor.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions oneflow/core/ndarray/xpu_broadcast_ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(2);
SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(3);
SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(4);
SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(5);
SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL(6);
#undef SPECIALIZE_XPU_BROADCAST_NDARRAY_UTIL
#undef IMPLACE_SET_SRC_COORD

Expand Down
1 change: 1 addition & 0 deletions oneflow/core/ndarray/xpu_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ SPECIALIZE_XPU_SHAPE_UTIL(1);
SPECIALIZE_XPU_SHAPE_UTIL(2);
SPECIALIZE_XPU_SHAPE_UTIL(3);
SPECIALIZE_XPU_SHAPE_UTIL(4);
SPECIALIZE_XPU_SHAPE_UTIL(5);
#undef SPECIALIZE_XPU_SHAPE_UTIL
#undef EXTRACT_COORD
#undef COORD_MUL_STRIDE
Expand Down
1 change: 1 addition & 0 deletions oneflow/core/ndarray/xpu_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ namespace oneflow {
#define GET_SEQ_3 GET_SEQ_2 OF_PP_MAKE_TUPLE_SEQ(3)
#define GET_SEQ_4 GET_SEQ_3 OF_PP_MAKE_TUPLE_SEQ(4)
#define GET_SEQ_5 GET_SEQ_4 OF_PP_MAKE_TUPLE_SEQ(5)
#define GET_SEQ_6 GET_SEQ_5 OF_PP_MAKE_TUPLE_SEQ(6)

} // namespace oneflow

Expand Down
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def atexit_hook(hook):
adaptive_avg_pool2d,
adaptive_avg_pool3d,
)
from oneflow.nn.modules.einsum import einsum_op as einsum
from oneflow.nn.modules.is_tensor import is_tensor_op as is_tensor
from oneflow.nn.modules.arange import arange_op as arange
from oneflow.nn.modules.linspace import linspace_op as linspace
Expand Down
1 change: 1 addition & 0 deletions python/oneflow/framework/docstr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,5 @@
from .sort import *
from .is_floating_point import *
from .where import *
from .einsum import *
from .oneflow import *
122 changes: 122 additions & 0 deletions python/oneflow/framework/docstr/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow
from oneflow.framework.docstr.utils import add_docstr

add_docstr(
oneflow.einsum,
"""
einsum(equation, *operands) -> oneflow.Tensor

Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
based on the Einstein summation convention.

Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
with some subscript and define which subscripts are part of the output. The output is then computed by summing
the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
output. For example, matrix multiplication can be computed using einsum as `flow.einsum("ij,jk->ik", A, B)`.
Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).

Equation:

The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.

Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
followed by the subscripts for the output. For instance, the following equation computes the transpose of a
matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
at most once for the output.

Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` covers the third and fourth
dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
before the subscript labels that appear exactly once for the input operands. For example, the following equation implements
batch matrix multiplication `'...ij,...jk'`.

A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.

.. note::

``flow.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.

.. note::

This function does not optimize the given expression, so a different formula for the same computation may
run faster or consume less memory. Projects like opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/)
can optimize the formula for you.

Args:
equation (String): The subscripts for the Einstein summation.
*operands (oneflow.Tensor): The tensors to compute the Einstein summation of.

For example:

.. code-block:: python

>>> import oneflow as flow

# trace
>>> flow.einsum('ii', flow.arange(4*4).reshape(4,4).to(flow.float32))
tensor(30., dtype=oneflow.float32)

# diagonal
>>> flow.einsum('ii->i', flow.arange(4*4).reshape(4,4).to(flow.float32))
tensor([ 0., 5., 10., 15.], dtype=oneflow.float32)

# outer product
>>> x = flow.arange(5).to(flow.float32)
>>> y = flow.arange(4).to(flow.float32)
>>> flow.einsum('i,j->ij', x, y)
tensor([[ 0., 0., 0., 0.],
[ 0., 1., 2., 3.],
[ 0., 2., 4., 6.],
[ 0., 3., 6., 9.],
[ 0., 4., 8., 12.]], dtype=oneflow.float32)

# batch matrix multiplication
>>> As = flow.arange(3*2*5).reshape(3,2,5).to(flow.float32)
>>> Bs = flow.arange(3*5*4).reshape(3,5,4).to(flow.float32)
>>> flow.einsum('bij,bjk->bik', As, Bs).shape
oneflow.Size([3, 2, 4])

# batch permute
>>> A = flow.randn(2, 3, 4, 5)
>>> flow.einsum('...ij->...ji', A).shape
oneflow.Size([2, 3, 5, 4])

# bilinear
>>> A = flow.randn(3,5,4)
>>> l = flow.randn(2,5)
>>> r = flow.randn(2,4)
>>> flow.einsum('bn,anm,bm->ba', l, A, r).shape
oneflow.Size([2, 3])

""",
)
26 changes: 26 additions & 0 deletions python/oneflow/nn/modules/einsum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow


def einsum_op(equation, *operands):
    # Python-level entry point for einsum.
    #
    # ``equation`` is the subscript string and ``operands`` collects every
    # remaining positional argument into a tuple; both are handed unchanged
    # to the functional binding ``flow._C.einsum``, whose result is returned.
    #
    # NOTE(review): the user-facing documentation for this function is
    # attached separately through ``add_docstr(oneflow.einsum, ...)``, so
    # plain comments are used here instead of a docstring -- confirm whether
    # ``add_docstr`` tolerates a pre-existing ``__doc__`` before adding one.
    einsum_impl = flow._C.einsum
    return einsum_impl(equation, operands)


if __name__ == "__main__":
    # When this file is executed as a script, run doctest over the module.
    # raise_on_error=True makes the first failing example raise an exception
    # instead of only being listed in a summary report.
    from doctest import testmod

    testmod(raise_on_error=True)
Loading