enable symbolic backward optimization with einsum_path
Fan committed Jul 24, 2019
1 parent 19aa577 commit 7bcdcb3
Showing 9 changed files with 215 additions and 53 deletions.
10 changes: 5 additions & 5 deletions benchmark/python/einsum/benchmark_einsum.py
@@ -41,11 +41,11 @@ def test_np_einsum():
print("Basic einsum: {} ms".format(cost * 1000))

# Sub-optimal einsum
cost = measure_cost(500, np.einsum, *args, optimize='optimal')
print("Optimal einsum: {} ms".format(cost * 1000))
# cost = measure_cost(500, np.einsum, *args, optimize='optimal')
# print("Optimal einsum: {} ms".format(cost * 1000))

# Greedy einsum
cost = measure_cost(500, np.einsum, *args, optimize='greedy')
cost = measure_cost(500, np.einsum, *args, optimize=True)
print("Greedy einsum: {} ms".format(cost * 1000))

print('Inner Product:')
@@ -55,7 +55,7 @@ def test_np_einsum():
     cost = measure_cost(50, np.tensordot, *args, axes=([0],[0]))
     print('Tensordot: {} ms'.format(cost * 1000))
     args = ['i, i', a, b]
-    cost = measure_cost(50, np.einsum, *args, optimize='greedy')
+    cost = measure_cost(50, np.einsum, *args, optimize=True)
     print('Greedy einsum: {} ms'.format(cost * 1000))
     cost = measure_cost(50, np.einsum, *args)
     print('Basic einsum: {} ms'.format(cost * 1000))
@@ -67,7 +67,7 @@ def test_np_einsum():
     cost = measure_cost(50, np.tensordot, *args, axes=([1],[0]))
     print('Tensordot: {} ms'.format(cost * 1000))
     args = ['ij, jk', a, b]
-    cost = measure_cost(50, np.einsum, *args, optimize='greedy')
+    cost = measure_cost(50, np.einsum, *args, optimize=True)
     print('Greedy einsum: {} ms'.format(cost * 1000))
     cost = measure_cost(50, np.einsum, *args)
     print('Basic einsum: {} ms'.format(cost * 1000))
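The `measure_cost` helper is not shown in this diff; a minimal stand-alone sketch of the same comparison against stock NumPy (the `measure_cost` below is a hypothetical stand-in that averages wall-clock time over `repeat` calls):

```python
import time
import numpy as np

def measure_cost(repeat, func, *args, **kwargs):
    # hypothetical stand-in for the benchmark helper used above
    start = time.time()
    for _ in range(repeat):
        func(*args, **kwargs)
    return (time.time() - start) / repeat

a = np.ones(64).reshape(2, 4, 8)
args = ['ijk,ilm,njm,nlk,abc->', a, a, a, a, a]
print('Basic einsum: {} ms'.format(measure_cost(50, np.einsum, *args) * 1000))
print('Greedy einsum: {} ms'.format(
    measure_cost(50, np.einsum, *args, optimize=True) * 1000))
```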
16 changes: 6 additions & 10 deletions python/mxnet/ndarray/numpy/_op.py
@@ -1934,10 +1934,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : ndarray, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.
 
     Returns
     -------
@@ -2012,8 +2011,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.
     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.
 
     Examples
     --------
@@ -2123,11 +2122,8 @@ def einsum(*operands, **kwargs):
     # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
     >>> for iteration in range(500):
     ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
-    # Optimal `einsum`: ~0.672ms
-    # Greedy `einsum` (faster optimal path approximation): ~0.117ms
-    >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
+    # Greedy `einsum` (faster optimal path approximation): ~0.306ms
     >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
+    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True)
     """
     return _einsum_path_util._einsum('ndarray', *operands, **kwargs)
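For context, a minimal end-to-end sketch of what this change enables on the ndarray side, assuming an MXNet build that ships the numpy interface (`npx.set_np()`): gradients can now flow through an einsum executed with an optimized contraction path.

```python
import mxnet as mx
from mxnet import autograd, np, npx
npx.set_np()

a = np.ones((2, 3))
b = np.ones((3, 4))
a.attach_grad()
b.attach_grad()
with autograd.record():
    # runs through the optimized (path-based) einsum implementation
    out = np.einsum('ij,jk->ik', a, b, optimize=True)
out.backward()
print(a.grad.shape)  # (2, 3)
```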
16 changes: 6 additions & 10 deletions python/mxnet/numpy/multiarray.py
@@ -3115,10 +3115,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : ndarray, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.
 
     Returns
     -------
@@ -3193,8 +3192,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.
     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.
 
     Examples
     --------
@@ -3304,11 +3303,8 @@ def einsum(*operands, **kwargs):
     # Basic `einsum`: ~42.22ms (benchmarked on 3.4GHz Intel Xeon.)
     >>> for iteration in range(500):
     ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
-    # Optimal `einsum`: ~0.672ms
-    # Greedy `einsum` (faster optimal path approximation): ~0.117ms
-    >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')
+    # Greedy `einsum` (faster optimal path approximation): ~0.306ms
     >>> for iteration in range(500):
-    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
+    ...     np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True)
     """
     return _mx_nd_np.einsum(*operands, **kwargs)
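The 'greedy' strategy referenced in the docstring is the same one exposed by NumPy's `np.einsum_path`; a quick pure-NumPy way to inspect the path it picks, shown here for illustration:

```python
import numpy as np

a = np.ones(64).reshape(2, 4, 8)
path, description = np.einsum_path('ijk,ilm,njm,nlk,abc->', a, a, a, a, a,
                                   optimize='greedy')
print(path)         # e.g. ['einsum_path', (1, 3), (0, 2), (0, 2), (0, 1)]
print(description)  # per-step subscripts, FLOP count and memory estimate
```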
9 changes: 4 additions & 5 deletions python/mxnet/symbol/numpy/_symbol.py
@@ -2388,10 +2388,9 @@ def einsum(*operands, **kwargs):
         These are the arrays for the operation.
     out : _Symbol, optional
         If provided, the calculation is done into this array.
-    optimize : {False, True, 'greedy', 'optimal'}, optional
+    optimize : {False, True}, optional
         Controls if intermediate optimization should occur. No optimization
-        will occur if False and True will default to the 'greedy' algorithm.
-        Defaults to False.
+        will occur if False.
 
     Returns
     -------
@@ -2466,8 +2465,8 @@ def einsum(*operands, **kwargs):
     memory footprint during computation.
     Typically a 'greedy' algorithm is applied which empirical tests have shown
-    returns the optimal path in the majority of cases. In some cases 'optimal'
-    will return the superlative path through a more expensive, exhaustive search.
+    returns the optimal path in the majority of cases. 'optimal' is not supported
+    for now.
     """
     return _einsum_path_util._einsum('symbol', *operands, **kwargs)

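Since the `_Symbol` version now routes through the same optimized path (this is the "symbolic backward optimization" of the commit title), einsum with `optimize=True` should also work under hybridization. A sketch, assuming the Gluon numpy interface of that era exposes the symbol namespace as `F.np` inside `hybrid_forward`:

```python
import mxnet as mx
from mxnet import gluon, np, npx
npx.set_np()

class PairwiseScores(gluon.HybridBlock):
    def hybrid_forward(self, F, x, y):
        # when hybridized, F is the symbol namespace, so this call
        # goes through the _Symbol einsum patched above
        return F.np.einsum('bik,bjk->bij', x, y, optimize=True)

block = PairwiseScores()
block.hybridize()
out = block(np.ones((2, 5, 4)), np.ones((2, 7, 4)))
print(out.shape)  # (2, 5, 7)
```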
164 changes: 146 additions & 18 deletions src/operator/numpy/np_einsum_op-inl.h
@@ -20,7 +20,7 @@
 /*
  * Copyright (c) 2005-2019, NumPy Developers.
  * All rights reserved.
- * 
+ *
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions are
  * met:
@@ -61,6 +61,7 @@
 #include <mxnet/operator_util.h>
 #include <string>
 #include <vector>
+#include <algorithm>
 #include "./np_tensordot_op-inl.h"
 #include "./np_einsum_path_op-inl.h"
 #include "../../common/static_array.h"
@@ -399,6 +400,20 @@ struct NumpyEinsumParam: public dmlc::Parameter<NumpyEinsumParam> {
   }
 };
 
+class EinsumOp {
+ public:
+  int num_args;
+  int optimize;
+  std::string subscripts;
+  std::shared_ptr<NDArray> tempspace;
+  std::vector<Step> paths;
+  explicit EinsumOp(int num_args, int optimize, std::string subscripts) {
+    this->num_args = num_args;
+    this->optimize = optimize;
+    this->subscripts = subscripts;
+  }
+}; // class EinsumOp
+
 template<int dimension, int req, bool back>
 struct numpy_einsum {
   template<typename DType>
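The new EinsumOp turns einsum into a stateful operator: the contraction path (`paths`) is computed once during the forward pass, and the forward intermediates (`tempspace`) are kept alive so the backward pass can reuse them instead of recomputing. A rough Python analogue of the caching idea (a hypothetical class, shown only to illustrate the state being kept):

```python
import numpy as np

class CachedEinsum:
    def __init__(self, subscripts):
        self.subscripts = subscripts
        self.paths = None       # contraction order, computed on first call
        self.tempspace = None   # would hold forward intermediates for backward

    def forward(self, *operands):
        if self.paths is None:
            self.paths, _ = np.einsum_path(self.subscripts, *operands,
                                           optimize='greedy')
        # np.einsum accepts a precomputed path via optimize=
        return np.einsum(self.subscripts, *operands, optimize=self.paths)
```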
@@ -751,28 +766,28 @@ inline void NumpyEinsumProcess(const std::vector<TBlob>& inputs,
 }
 
 template<typename xpu>
-inline void NumpyEinsumForward(const nnvm::NodeAttrs& attrs,
+inline void NumpyEinsumForward(const OpStatePtr& state_ptr,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mxnet_op;
-  const NumpyEinsumParam &param = nnvm::get<NumpyEinsumParam>(attrs.parsed);
-  int num_args = param.num_args;
-  int optimize = param.optimize;
-  const char* subscripts = param.subscripts.c_str();
+  EinsumOp& state = state_ptr.get_state<EinsumOp>();
+  int num_args = state.num_args;
+  int optimize = state.optimize;
+  const char* subscripts = state.subscripts.c_str();
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), num_args);
   CHECK_EQ(outputs.size(), 1U);
   if (optimize == 0) {
     NumpyEinsumProcess<xpu, 0>(inputs, req, outputs, subscripts, num_args, ctx);
     return;
   }
-  std::vector<Step> paths;
+  std::vector<Step>& paths = state.paths;
   std::vector<std::vector<int> > pos;
   std::string string_repr;
-  paths = einsum_path(param.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
+  paths = einsum_path(state.subscripts, inputs, true, ctx.run_ctx, &pos, &string_repr);
   size_t paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
   std::vector<TBlob> operands(inputs), tmp_operands, temp_space_vec(paths_len - 1);
   for (int i = 0; i < paths_len - 1; ++i) {
@@ -783,8 +798,11 @@ inline void NumpyEinsumForward(const nnvm::NodeAttrs& attrs,
   }
   temp_space_size += max_temp_space_size;
   MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-    Tensor<xpu, 1, DType> temp_space =
-      ctx.requested[2].get_space_typed<xpu, 1, DType>(Shape1(temp_space_size), s);
+    state.tempspace.reset<NDArray>(new NDArray(TShape(Shape1(temp_space_size)),
+                                               ctx.run_ctx.ctx,
+                                               false,
+                                               outputs[0].type_flag_));
+    Tensor<xpu, 1, DType> temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
     size_t begin = max_temp_space_size;
     for (int i = 0; i < paths_len - 1; ++i) {
       TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
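The layout being set up here packs every intermediate into one flat buffer: the first `max_temp_space_size` elements are scratch, and step i's output occupies a fixed slice after it. A small sketch of the same offset arithmetic (illustrative Python, mirroring the loop above):

```python
def plan_tempspace(step_sizes):
    # step_sizes[i] = element count of step i's output; the final step
    # writes the real output, so it gets no slot of its own
    max_size = max(step_sizes)
    total = max_size + sum(step_sizes[:-1])
    slots, begin = [], max_size
    for size in step_sizes[:-1]:
        slots.append((begin, begin + size))  # intermediate i lives here
        begin += size
    return total, slots
```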
@@ -795,7 +813,7 @@ inline void NumpyEinsumForward(const nnvm::NodeAttrs& attrs,
       tmp_operands.clear();
 
       // We remove inds from right to left
-      for (const int& p: paths[i].contract_inds) {
+      for (const int& p : paths[i].contract_inds) {
         tmp_operands.push_back(operands[p]);
         operands.erase(operands.begin() + p);
       }
@@ -826,7 +844,7 @@
                                    tmp_operands[0],
                                    tmp_operands[1],
                                    temp_space_vec[i],
-                                  std::vector<OpReqType>{OpReqType::kWriteTo});
+                                   std::vector<OpReqType>{OpReqType::kWriteTo});
       }
     } else {
       NumpyEinsumProcess<xpu, 0>(tmp_operands,
@@ -837,23 +855,133 @@ inline void NumpyEinsumForward(const nnvm::NodeAttrs& attrs,
       if (!handle_out)
         operands.push_back(temp_space_vec[i]);
     }
-  })
+  });
 }

 template<typename xpu>
-inline void NumpyEinsumBackward(const nnvm::NodeAttrs& attrs,
+inline void NumpyEinsumBackward(const OpStatePtr& state_ptr,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mshadow_op;
-  const NumpyEinsumParam &param = nnvm::get<NumpyEinsumParam>(attrs.parsed);
-  int num_args = param.num_args;
-  const char* subscripts = param.subscripts.c_str();
+  const EinsumOp& state = state_ptr.get_state<EinsumOp>();
+  int num_args = state.num_args;
+  int optimize = state.optimize;
+  const char* subscripts = state.subscripts.c_str();
   Stream<xpu> *s = ctx.get_stream<xpu>();
   CHECK_EQ(inputs.size(), 1 + num_args);
   CHECK_EQ(outputs.size(), num_args);
-  NumpyEinsumProcess<xpu, 1>(inputs, req, outputs, subscripts, num_args, ctx);
+  if (optimize == 0) {
+    NumpyEinsumProcess<xpu, 1>(inputs, req, outputs, subscripts, num_args, ctx);
+    return;
+  }
+  // calculate temporary space size for temp_grad
+  const std::vector<Step>& paths = state.paths;
+  size_t paths_len = paths.size(), temp_space_size = 0, max_temp_space_size = 0;
+  for (int i = 0; i < paths_len - 1; ++i) {
+    temp_space_size += paths[i].oshape.Size();
+  }
+  for (int i = 0; i < paths_len; ++i) {
+    max_temp_space_size = std::max(max_temp_space_size, paths[i].oshape.Size());
+  }
+  temp_space_size += max_temp_space_size;
+  // replay the forward process
+  std::vector<std::vector<int> > op_idx(paths_len + 1);
+  for (size_t i = 0; i <= paths_len; ++i) {
+    if (i == 0) {
+      op_idx[i].reserve(num_args);
+      for (int j = 0; j < num_args; ++j) {
+        op_idx[i].push_back(j + 1);
+      }
+    } else {
+      op_idx[i] = op_idx[i - 1];
+      // We remove inds from right to left
+      for (const int& p : paths[i - 1].contract_inds) {
+        op_idx[i].erase(op_idx[i].begin() + p);
+      }
+      op_idx[i].push_back(-static_cast<int>(i - 1));
+    }
+  }
+
+  // allocate temporary space and propagate
+  std::vector<TBlob> temp_grad(paths_len - 1), temp_data(paths_len - 1);
+  std::vector<TBlob> temp_inputs, temp_outputs;
+  std::vector<OpReqType> temp_req;
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    // allocate temporary space for gradients of intermediate results
+    Tensor<xpu, 1, DType> temp_space = ctx.requested[2].get_space_typed<xpu, 1, DType>
+        (Shape1(temp_space_size), s);
+    size_t begin = max_temp_space_size;
+    for (size_t i = 0; i + 1 < paths_len; ++i) {
+      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
+      temp_grad[i] = tblob.reshape(paths[i].oshape);
+      begin = begin + paths[i].oshape.Size();
+    }
+
+    // reinterpret ndarray for intermediate results
+    temp_space = state.tempspace->data().FlatTo1D<xpu, DType>();
+    begin = max_temp_space_size;
+    for (size_t i = 0; i + 1 < paths_len; ++i) {
+      TBlob tblob = TBlob(temp_space.Slice(begin, begin + paths[i].oshape.Size()));
+      temp_data[i] = tblob.reshape(paths[i].oshape);
+      begin = begin + paths[i].oshape.Size();
+    }
+
+    // go through the paths in reverse order
+    for (int i = paths_len - 1; i >= 0; i--) {
+      temp_inputs.clear();
+      temp_outputs.clear();
+      temp_req.clear();
+      bool handle_out = (i == paths_len - 1);
+
+      if (handle_out) {
+        temp_inputs.push_back(inputs[0]);
+      } else {
+        temp_inputs.push_back(temp_grad[i]);
+      }
+      for (auto p : paths[i].contract_inds) {
+        int idx = op_idx[i][p];
+        if (idx >= 1) {
+          temp_inputs.push_back(inputs[idx]);
+          temp_outputs.push_back(outputs[idx - 1]);
+          temp_req.push_back(req[idx - 1]);
+        } else {
+          temp_inputs.push_back(temp_data[-idx]);
+          temp_outputs.push_back(temp_grad[-idx]);
+          temp_req.push_back(OpReqType::kWriteTo);
+        }
+      }
+
+      if (paths[i].do_blas) {
+        CHECK_EQ(temp_inputs.size(), 3U);
+        CHECK_EQ(temp_outputs.size(), 2U);
+        CHECK_EQ(temp_req.size(), 2U);
+        if (paths[i].do_einsum) {
+          TBlob max_temp_space = TBlob(temp_space.Slice(0, paths[i].tshape.Size()));
+          max_temp_space = max_temp_space.reshape(paths[i].tshape);
+          NumpyEinsumProcess<xpu, 0>(std::vector<TBlob>{temp_inputs[0]},
+                                     std::vector<OpReqType>{kWriteTo},
+                                     std::vector<TBlob>{max_temp_space},
+                                     paths[i].einsum2blas_str.c_str(),
+                                     1, ctx);
+          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
+                                     max_temp_space, temp_inputs[1], temp_inputs[2],
+                                     temp_outputs[0], temp_outputs[1], temp_req);
+        } else {
+          TensordotBackwardImpl<xpu>(paths[i].left_pos, paths[i].right_pos, ctx,
+                                     temp_inputs[0], temp_inputs[1], temp_inputs[2],
+                                     temp_outputs[0], temp_outputs[1], temp_req);
+        }
+      } else {
+        NumpyEinsumProcess<xpu, 1>(temp_inputs, temp_req, temp_outputs,
+                                   paths[i].einsum_str.c_str(),
+                                   temp_outputs.size(),
+                                   ctx);
+      }
+    }
+  });
 }
 
 } // namespace op
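The trickiest bookkeeping above is `op_idx`: as the forward replay consumes operands, each surviving slot is tagged with `j + 1` if it is original input j, or with a non-positive `-k` if it is the intermediate produced by step k, so the backward loop can route each gradient either to `outputs[idx - 1]` or to `temp_grad[-idx]`. The same bookkeeping in a few lines of illustrative Python:

```python
def replay_operand_indices(num_args, contract_inds_per_step):
    op_idx = [[j + 1 for j in range(num_args)]]   # original inputs: positive tags
    for step, inds in enumerate(contract_inds_per_step):
        cur = list(op_idx[-1])
        for p in inds:            # contract_inds are removed right to left
            del cur[p]
        cur.append(-step)         # result of this step: non-positive tag
        op_idx.append(cur)
    return op_idx

# e.g. five operands, contracting operands (4, 3) first, then (2, 1):
# replay_operand_indices(5, [(4, 3), (2, 1)])
# -> [[1, 2, 3, 4, 5], [1, 2, 3, 0], [1, 0, -1]]
```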