Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add optimization for imperative mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Fan committed Jul 16, 2019
1 parent aea6a05 commit cdfaf1c
Show file tree
Hide file tree
Showing 10 changed files with 1,343 additions and 104 deletions.
53 changes: 53 additions & 0 deletions benchmark/python/einsum/benchmark_einsum.py
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
from mxnet import np, npx

def measure_cost(repeat, func_name, *args, **kwargs):
    """Measure the average wall-clock time of running a function.

    Parameters
    ----------
    repeat : int
        Number of times to invoke the function.
    func_name : callable
        The function to benchmark. NOTE(review): despite the name, this
        is the callable itself, not a string name.
    *args, **kwargs
        Arguments forwarded to the function on every call.

    Returns
    -------
    float
        Average seconds per call.
    """
    # Drain any asynchronously queued MXNet work so it is not timed.
    mx.nd.waitall()
    # perf_counter is monotonic and higher-resolution than time.time,
    # making it the appropriate clock for benchmarking.
    start = time.perf_counter()
    for _ in range(repeat):
        func_name(*args, **kwargs)
    # Block until all operations launched in the loop have finished,
    # otherwise the measurement would only cover the (async) dispatch.
    mx.nd.waitall()
    end = time.perf_counter()
    return (end - start) / repeat


def test_np_einsum():
    """Benchmark np.einsum without optimization and with each optimized path."""
    # Basic einsum: no contraction-order optimization.
    a = np.ones(64).reshape(2, 4, 8)
    args = ['ijk,ilm,njm,nlk,abc->', a, a, a, a, a]
    cost = measure_cost(500, np.einsum, *args)
    print("Basic einsum: {} ms".format(cost * 1000))

    # Optimal einsum: exhaustive search for the best contraction order.
    # (The original comment mislabeled this branch as "Sub-optimal".)
    cost = measure_cost(500, np.einsum, *args, optimize='optimal')
    print("Optimal einsum: {} ms".format(cost * 1000))

    # Greedy einsum: faster approximation of the optimal path.
    cost = measure_cost(500, np.einsum, *args, optimize='greedy')
    print("Greedy einsum: {} ms".format(cost * 1000))


if __name__ == "__main__":
    # Switch MXNet into NumPy-compatible mode before running the benchmark.
    npx.set_np()
    test_np_einsum()
38 changes: 33 additions & 5 deletions python/mxnet/ndarray/numpy/_op.py
Expand Up @@ -27,6 +27,7 @@
from ...context import current_context
from . import _internal as _npi
from ..ndarray import NDArray
from ...numpy_utils import _einsum_path_util

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
Expand Down Expand Up @@ -1825,9 +1826,9 @@ def arcsin(x, out=None, **kwargs):


@set_module('mxnet.ndarray.numpy')
def einsum(subscripts, *operands, **kwargs):
def einsum(*operands, **kwargs):
r"""
einsum(subscripts, *operands, out=None)
einsum(subscripts, *operands, out=None, optimize=False)
Evaluates the Einstein summation convention on the operands.
Expand All @@ -1853,6 +1854,10 @@ def einsum(subscripts, *operands, **kwargs):
These are the arrays for the operation.
out : ndarray, optional
If provided, the calculation is done into this array.
optimize : {False, True, 'greedy', 'optimal'}, optional
Controls if intermediate optimization should occur. No optimization
will occur if False and True will default to the 'greedy' algorithm.
Defaults to False.
Returns
-------
Expand Down Expand Up @@ -1919,7 +1924,16 @@ def einsum(subscripts, *operands, **kwargs):
When there is only one operand, no axes are summed, and no output
parameter is provided, a view into the operand is returned instead
of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
produces a view (changed in version 1.10.0).
produces a view.
The ``optimize`` argument controls the contraction order
of an einsum expression. For a contraction with three or more operands this
can greatly increase the computational efficiency at the cost of a larger
memory footprint during computation.
Typically a 'greedy' algorithm is applied, which empirical tests have shown
returns the optimal path in the majority of cases. In some cases 'optimal'
will return a better path through a more expensive, exhaustive search.
Examples
--------
Expand Down Expand Up @@ -2020,6 +2034,20 @@ def einsum(subscripts, *operands, **kwargs):
>>> np.einsum('k...,jk', a, b)
array([[10., 28., 46., 64.],
[13., 40., 67., 94.]])
Chained array operations. For more complicated contractions, speed ups
might be achieved by repeatedly computing a 'greedy' path. Performance
improvements can be particularly significant with larger arrays:
>>> a = np.ones(64).reshape(2,4,8)
# Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
# Optimal `einsum`: ~0.672ms
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
# Greedy `einsum` (faster optimal path approximation): ~0.306ms
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
"""
out = kwargs.get('out', None)
return _npi.einsum(*operands, subscripts=subscripts, out=out)
return _einsum_path_util._einsum('ndarray', *operands, **kwargs)
38 changes: 22 additions & 16 deletions python/mxnet/numpy/multiarray.py
Expand Up @@ -3007,9 +3007,9 @@ def arcsin(x, out=None, **kwargs):


@set_module('mxnet.numpy')
def einsum(subscripts, *operands, **kwargs):
def einsum(*operands, **kwargs):
r"""
einsum(subscripts, *operands, out=None)
einsum(subscripts, *operands, out=None, optimize=False)
Evaluates the Einstein summation convention on the operands.
Expand All @@ -3035,6 +3035,10 @@ def einsum(subscripts, *operands, **kwargs):
These are the arrays for the operation.
out : ndarray, optional
If provided, the calculation is done into this array.
optimize : {False, True, 'greedy', 'optimal'}, optional
Controls if intermediate optimization should occur. No optimization
will occur if False and True will default to the 'greedy' algorithm.
Defaults to False.
Returns
-------
Expand Down Expand Up @@ -3101,7 +3105,16 @@ def einsum(subscripts, *operands, **kwargs):
When there is only one operand, no axes are summed, and no output
parameter is provided, a view into the operand is returned instead
of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
produces a view (changed in version 1.10.0).
produces a view.
The ``optimize`` argument controls the contraction order
of an einsum expression. For a contraction with three or more operands this
can greatly increase the computational efficiency at the cost of a larger
memory footprint during computation.
Typically a 'greedy' algorithm is applied, which empirical tests have shown
returns the optimal path in the majority of cases. In some cases 'optimal'
will return a better path through a more expensive, exhaustive search.
Examples
--------
Expand Down Expand Up @@ -3204,25 +3217,18 @@ def einsum(subscripts, *operands, **kwargs):
[13., 40., 67., 94.]])
Chained array operations. For more complicated contractions, speed ups
might be achieved by repeatedly computing a 'greedy' path or pre-computing the
'optimal' path and repeatedly applying it, using an
`einsum_path` insertion (since version 1.12.0). Performance improvements can be
particularly significant with larger arrays:
might be achieved by repeatedly computing a 'greedy' path. Performance
improvements can be particularly significant with larger arrays:
>>> a = np.ones(64).reshape(2,4,8)
# Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
# Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
# Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
# Optimal `einsum`: ~0.672ms
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
# Greedy `einsum` (faster optimal path approximation): ~160ms
# Greedy `einsum` (faster optimal path approximation): ~0.306ms
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
# Optimal `einsum` (best usage pattern in some use cases): ~110ms
>>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
>>> for iteration in range(500):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
"""
out = kwargs.get('out', None)
return _mx_nd_np.einsum(subscripts, *operands, out=out)
return _mx_nd_np.einsum(*operands, **kwargs)
18 changes: 18 additions & 0 deletions python/mxnet/numpy_utils/__init__.py
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Module for numpy_utils"""

0 comments on commit cdfaf1c

Please sign in to comment.