enable symbolic backward optimization with einsum_path

Fan committed Aug 15, 2019
1 parent b8a70c2 commit 3255b3e
Showing 12 changed files with 250 additions and 80 deletions.
10 changes: 5 additions & 5 deletions benchmark/python/einsum/benchmark_einsum.py
@@ -41,11 +41,11 @@ def test_np_einsum():
     print("Basic einsum: {} ms".format(cost * 1000))

     # Sub-optimal einsum
-    cost = measure_cost(500, np.einsum, *args, optimize='optimal')
-    print("Optimal einsum: {} ms".format(cost * 1000))
+    # cost = measure_cost(500, np.einsum, *args, optimize='optimal')
+    # print("Optimal einsum: {} ms".format(cost * 1000))

     # Greedy einsum
-    cost = measure_cost(500, np.einsum, *args, optimize='greedy')
+    cost = measure_cost(500, np.einsum, *args, optimize=True)
     print("Greedy einsum: {} ms".format(cost * 1000))

     print('Inner Product:')
@@ -55,7 +55,7 @@ def test_np_einsum():
     cost = measure_cost(50, np.tensordot, *args, axes=([0],[0]))
     print('Tensordot: {} ms'.format(cost * 1000))
     args = ['i, i', a, b]
-    cost = measure_cost(50, np.einsum, *args, optimize='greedy')
+    cost = measure_cost(50, np.einsum, *args, optimize=True)
     print('Greedy einsum: {} ms'.format(cost * 1000))
     cost = measure_cost(50, np.einsum, *args)
     print('Basic einsum: {} ms'.format(cost * 1000))
@@ -67,7 +67,7 @@ def test_np_einsum():
     cost = measure_cost(50, np.tensordot, *args, axes=([1],[0]))
     print('Tensordot: {} ms'.format(cost * 1000))
     args = ['ij, jk', a, b]
-    cost = measure_cost(50, np.einsum, *args, optimize='greedy')
+    cost = measure_cost(50, np.einsum, *args, optimize=True)
     print('Greedy einsum: {} ms'.format(cost * 1000))
     cost = measure_cost(50, np.einsum, *args)
     print('Basic einsum: {} ms'.format(cost * 1000))
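Note: the hunks above call a measure_cost helper defined earlier in the benchmark file, outside this diff. A minimal sketch of what such a timing helper could look like, assuming it simply averages wall-clock time over repeated calls (for asynchronous MXNet ops the caller would additionally need a synchronization such as mx.nd.waitall()):

    import time

    def measure_cost(repeat, func, *args, **kwargs):
        """Average seconds per call of func(*args, **kwargs)."""
        func(*args, **kwargs)  # warm-up so one-time setup is not timed
        start = time.time()
        for _ in range(repeat):
            func(*args, **kwargs)
        return (time.time() - start) / repeat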
16 changes: 6 additions & 10 deletions python/mxnet/ndarray/numpy/_op.py
@@ -737,10 +737,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : ndarray, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.

     Returns
     -------
@@ -815,8 +814,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.

     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.

     Examples
     --------
@@ -926,11 +925,8 @@ def einsum(*operands, **kwargs):
     # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
     >>> for iteration in range(500):
     ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
-    # Optimal `einsum`: ~0.672ms
+    # Greedy `einsum` (faster optimal path approximation): ~0.117ms
     >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
-    # Greedy `einsum` (faster optimal path approximation): ~0.306ms
-    >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
+    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True)
     """
     return _einsum_path_util._einsum('ndarray', *operands, **kwargs)
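With the revised docstring, the ndarray frontend treats optimize as a boolean toggle. A short usage sketch, assuming an MXNet build with the numpy interface enabled via npx.set_np() (shapes follow the docstring example above):

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible array semantics

    a = np.ones(64).reshape(2, 4, 8)
    # Same multi-operand contraction as the docstring benchmark above:
    # optimize=True computes a greedy contraction path first, replacing
    # one large joint loop with a sequence of cheap pairwise contractions.
    out = np.einsum('ijk,ilm,njm,nlk,abc->', a, a, a, a, a, optimize=True)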
16 changes: 6 additions & 10 deletions python/mxnet/numpy/multiarray.py
@@ -1908,10 +1908,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : ndarray, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.

     Returns
     -------
@@ -1986,8 +1985,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.

     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.

     Examples
     --------
@@ -2097,11 +2096,8 @@ def einsum(*operands, **kwargs):
     # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
     >>> for iteration in range(500):
     ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
-    # Optimal `einsum`: ~0.672ms
+    # Greedy `einsum` (faster optimal path approximation): ~0.117ms
     >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
-    # Greedy `einsum` (faster optimal path approximation): ~0.306ms
-    >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
+    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True)
     """
     return _mx_nd_np.einsum(*operands, **kwargs)
9 changes: 4 additions & 5 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -1365,10 +1365,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : _Symbol, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.

     Returns
     -------
@@ -1443,8 +1442,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.

     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.
     """
     return _einsum_path_util._einsum('symbol', *operands, **kwargs)

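The contraction path that this commit threads through both frontends is the same kind of object NumPy exposes via np.einsum_path. For reference only, a sketch using stock NumPy to inspect the greedy path for the contraction used in the docstrings above (the printed path values are illustrative):

    import numpy as onp  # stock NumPy, purely for illustration

    a = onp.ones(64).reshape(2, 4, 8)
    path, report = onp.einsum_path('ijk,ilm,njm,nlk,abc->',
                                   a, a, a, a, a, optimize='greedy')
    print(path)    # e.g. ['einsum_path', (0, 1), (0, 3), (0, 2), (0, 1)]
    print(report)  # per-step breakdown of scaling and FLOP savings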
2 changes: 1 addition & 1 deletion src/imperative/imperative_utils.h
@@ -274,7 +274,7 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
     }
   }
   // relax the constraint for einsum (which needs 3) and tensordot (which needs 2)
-  CHECK_LE(ntmp, 3) << "Only support 3 temp space requests";
+  CHECK_LE(ntmp, 4) << "Only support 4 temp space requests";
 }

 // append extra resource requests for storage fallback
4 changes: 2 additions & 2 deletions src/operator/numpy/np_dot.cc
@@ -129,7 +129,7 @@ NNVM_REGISTER_OP(_np_dot)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(2, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute<cpu>", NumpyDotForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_np_dot"})
@@ -142,7 +142,7 @@ NNVM_REGISTER_OP(_backward_np_dot)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr<FResourceRequest>("FResourceRequest",
   [](const NodeAttrs& attrs) {
-    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+    return std::vector<ResourceRequest>(2, ResourceRequest::kTempSpace);
   })
 .set_attr<FCompute>("FCompute<cpu>", NumpyDotBackward<cpu>);

